Adding some slightly more exhaustive strided_slice test parameters.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 26 Apr 2018 20:35:35 +0000 (13:35 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 26 Apr 2018 20:37:55 +0000 (13:37 -0700)
PiperOrigin-RevId: 194446000

tensorflow/contrib/lite/kernels/internal/BUILD
tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h [new file with mode: 0644]
tensorflow/contrib/lite/testing/generate_examples.py
tensorflow/contrib/lite/toco/BUILD
tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc

index 67dd188..dce14cd 100644 (file)
@@ -155,6 +155,7 @@ cc_library(
     copts = tflite_copts(),
     deps = [
         ":quantization_util",
+        ":strided_slice_logic",
         ":types",
         ":round",
         "//third_party/eigen3",
@@ -230,6 +231,17 @@ cc_test(
 )
 
 cc_library(
+    name = "strided_slice_logic",
+    srcs = [],
+    hdrs = [
+        "strided_slice_logic.h",
+    ],
+    deps = [
+        ":types",
+    ],
+)
+
+cc_library(
     name = "reference_base",
     srcs = [],
     hdrs = [
@@ -241,6 +253,7 @@ cc_library(
     deps = [
         ":quantization_util",
         ":round",
+        ":strided_slice_logic",
         ":types",
         "//third_party/eigen3",
         "@gemmlowp",
index 9e9aba0..3d6042c 100644 (file)
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/contrib/lite/kernels/internal/common.h"
 #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
 #include "tensorflow/contrib/lite/kernels/internal/types.h"
 
 namespace tflite {
@@ -5864,90 +5865,7 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
       output_dims, 0);
 }
 
-// UNOPTIMIZED COPY of StridedSlice from reference_ops.h (see comments there).
-
-// Use until std::clamp() is available from C++17.
-inline int Clamp(const int v, const int lo, const int hi) {
-  TFLITE_DCHECK(!(hi < lo));
-  if (hi < v) return hi;
-  if (v < lo) return lo;
-  return v;
-}
-
-inline int StartForAxis(int begin_mask, const std::vector<int>& start_indices,
-                        const std::vector<int>& strides,
-                        const Dims<4>& input_shape, int axis) {
-  // Begin with the specified index
-  int start = start_indices[axis];
-
-  // begin_mask override
-  if (begin_mask & 1 << axis) {
-    if (strides[axis] > 0) {
-      // Forward iteration - use the first element. These values will get
-      // clamped below (Note: We could have set them to 0 and axis_size-1, but
-      // use lowest() and max() to maintain symmetry with StopForAxis())
-      start = std::numeric_limits<int>::lowest();
-    } else {
-      // Backward iteration - use the last element.
-      start = std::numeric_limits<int>::max();
-    }
-  }
-
-  // Handle negative indices
-  int axis_size = input_shape.sizes[axis];
-  if (start < 0) {
-    start += axis_size;
-  }
-
-  // Clamping
-  start = Clamp(start, 0, axis_size - 1);
-
-  return start;
-}
-
-inline int StopForAxis(int end_mask, const std::vector<int>& stop_indices,
-                       const std::vector<int>& strides,
-                       const Dims<4>& input_shape, int axis) {
-  // Begin with the specified index
-  int stop = stop_indices[axis];
-
-  // end_mask override
-  if (end_mask & (1 << axis)) {
-    if (strides[axis] > 0) {
-      // Forward iteration - use the last element. These values will get
-      // clamped below
-      stop = std::numeric_limits<int>::max();
-    } else {
-      // Backward iteration - use the first element.
-      stop = std::numeric_limits<int>::lowest();
-    }
-  }
-
-  // Handle negative indices
-  int axis_size = input_shape.sizes[axis];
-  if (stop < 0) {
-    stop += axis_size;
-  }
-
-  // Clamping
-  // Because the end index points one past the last element, we need slightly
-  // different clamping ranges depending on the direction.
-  if (strides[axis] > 0) {
-    // Forward iteration
-    stop = Clamp(stop, 0, axis_size);
-  } else {
-    // Backward iteration
-    stop = Clamp(stop, -1, axis_size - 1);
-  }
-
-  return stop;
-}
-
-inline bool LoopCondition(int index, int stop, int stride) {
-  // True when we have reached the end of an axis and should loop.
-  return stride > 0 ? index >= stop : index <= stop;
-}
-
+// UNOPTIMIZED COPY of StridedSlice from reference_ops.h.
 template <typename T>
 inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
                          int begin_mask, int end_mask,
@@ -5958,31 +5876,35 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
   TFLITE_DCHECK_EQ(start_indices.size(), 4);
   TFLITE_DCHECK_EQ(stop_indices.size(), 4);
   TFLITE_DCHECK_EQ(strides.size(), 4);
-  const int start_b =
-      StartForAxis(begin_mask, start_indices, strides, input_dims, 3);
-  const int stop_b =
-      StopForAxis(end_mask, stop_indices, strides, input_dims, 3);
-  const int start_h =
-      StartForAxis(begin_mask, start_indices, strides, input_dims, 2);
-  const int stop_h =
-      StopForAxis(end_mask, stop_indices, strides, input_dims, 2);
-  const int start_w =
-      StartForAxis(begin_mask, start_indices, strides, input_dims, 1);
-  const int stop_w =
-      StopForAxis(end_mask, stop_indices, strides, input_dims, 1);
-  const int start_d =
-      StartForAxis(begin_mask, start_indices, strides, input_dims, 0);
-  const int stop_d =
-      StopForAxis(end_mask, stop_indices, strides, input_dims, 0);
+  const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
+                                                  strides, input_dims.sizes, 3);
+  const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+                                                input_dims.sizes, 3);
+  const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
+                                                  strides, input_dims.sizes, 2);
+  const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+                                                input_dims.sizes, 2);
+  const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
+                                                  strides, input_dims.sizes, 1);
+  const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+                                                input_dims.sizes, 1);
+  const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
+                                                  strides, input_dims.sizes, 0);
+  const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+                                                input_dims.sizes, 0);
 
   T* out_ptr = output_data;
-  for (int in_b = start_b; !LoopCondition(in_b, stop_b, strides[3]);
+  for (int in_b = start_b;
+       !strided_slice::LoopCondition(in_b, stop_b, strides[3]);
        in_b += strides[3]) {
-    for (int in_h = start_h; !LoopCondition(in_h, stop_h, strides[2]);
+    for (int in_h = start_h;
+         !strided_slice::LoopCondition(in_h, stop_h, strides[2]);
          in_h += strides[2]) {
-      for (int in_w = start_w; !LoopCondition(in_w, stop_w, strides[1]);
+      for (int in_w = start_w;
+           !strided_slice::LoopCondition(in_w, stop_w, strides[1]);
            in_w += strides[1]) {
-        for (int in_d = start_d; !LoopCondition(in_d, stop_d, strides[0]);
+        for (int in_d = start_d;
+             !strided_slice::LoopCondition(in_d, stop_d, strides[0]);
              in_d += strides[0]) {
           *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
         }
index 4c8cbe4..d41ade4 100644 (file)
@@ -29,6 +29,7 @@ limitations under the License.
 #include "tensorflow/contrib/lite/kernels/internal/common.h"
 #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
 #include "tensorflow/contrib/lite/kernels/internal/types.h"
 
 namespace tflite {
@@ -3131,104 +3132,6 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
       output_dims, 0);
 }
 
-// STRIDED SLICE
-// The functions below for StridedSlice are mirrored in a number of places:
-//
-//   propagate_fixed_sizes.cc
-//   propagate_shapes.cc
-//   resolve_constant_strided_slice.cc
-//   optimized_ops.h
-//
-// It is designed for an arbitrary number of dimensions, even though dimensions
-// here are fixed at 4. This is because we expect to eventually support
-// arbitrary dimensionality. Also note that the axis orders are reversed for
-// runtime ops, and so the indices and masks must be as well too.
-//
-// Be warned this code involves some rather subtle logic of python slicing. The
-// best "ground truth" is to compare results to actual python execution.
-
-// Use until std::clamp() is available from C++17.
-inline int Clamp(const int v, const int lo, const int hi) {
-  TFLITE_DCHECK(!(hi < lo));
-  if (hi < v) return hi;
-  if (v < lo) return lo;
-  return v;
-}
-
-inline int StartForAxis(int begin_mask, const std::vector<int>& start_indices,
-                        const std::vector<int>& strides,
-                        const Dims<4>& input_shape, int axis) {
-  // Begin with the specified index
-  int start = start_indices[axis];
-
-  // begin_mask override
-  if (begin_mask & 1 << axis) {
-    if (strides[axis] > 0) {
-      // Forward iteration - use the first element. These values will get
-      // clamped below (Note: We could have set them to 0 and axis_size-1, but
-      // use lowest() and max() to maintain symmetry with StopForAxis())
-      start = std::numeric_limits<int>::lowest();
-    } else {
-      // Backward iteration - use the last element.
-      start = std::numeric_limits<int>::max();
-    }
-  }
-
-  // Handle negative indices
-  int axis_size = input_shape.sizes[axis];
-  if (start < 0) {
-    start += axis_size;
-  }
-
-  // Clamping
-  start = Clamp(start, 0, axis_size - 1);
-
-  return start;
-}
-
-inline int StopForAxis(int end_mask, const std::vector<int>& stop_indices,
-                       const std::vector<int>& strides,
-                       const Dims<4>& input_shape, int axis) {
-  // Begin with the specified index
-  int stop = stop_indices[axis];
-
-  // end_mask override
-  if (end_mask & (1 << axis)) {
-    if (strides[axis] > 0) {
-      // Forward iteration - use the last element. These values will get
-      // clamped below
-      stop = std::numeric_limits<int>::max();
-    } else {
-      // Backward iteration - use the first element.
-      stop = std::numeric_limits<int>::lowest();
-    }
-  }
-
-  // Handle negative indices
-  int axis_size = input_shape.sizes[axis];
-  if (stop < 0) {
-    stop += axis_size;
-  }
-
-  // Clamping
-  // Because the end index points one past the last element, we need slightly
-  // different clamping ranges depending on the direction.
-  if (strides[axis] > 0) {
-    // Forward iteration
-    stop = Clamp(stop, 0, axis_size);
-  } else {
-    // Backward iteration
-    stop = Clamp(stop, -1, axis_size - 1);
-  }
-
-  return stop;
-}
-
-inline bool LoopCondition(int index, int stop, int stride) {
-  // True when we have reached the end of an axis and should loop.
-  return stride > 0 ? index >= stop : index <= stop;
-}
-
 template <typename T>
 inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
                          int begin_mask, int end_mask,
@@ -3236,34 +3139,40 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
                          const std::vector<int>& stop_indices,
                          const std::vector<int>& strides, T* output_data,
                          const Dims<4>& output_dims) {
+  // Note that the axis orders are reversed for runtime ops, so the indices,
+  // strides and masks must be as well too.
   TFLITE_DCHECK_EQ(start_indices.size(), 4);
   TFLITE_DCHECK_EQ(stop_indices.size(), 4);
   TFLITE_DCHECK_EQ(strides.size(), 4);
-  const int start_b =
-      StartForAxis(begin_mask, start_indices, strides, input_dims, 3);
-  const int stop_b =
-      StopForAxis(end_mask, stop_indices, strides, input_dims, 3);
-  const int start_h =
-      StartForAxis(begin_mask, start_indices, strides, input_dims, 2);
-  const int stop_h =
-      StopForAxis(end_mask, stop_indices, strides, input_dims, 2);
-  const int start_w =
-      StartForAxis(begin_mask, start_indices, strides, input_dims, 1);
-  const int stop_w =
-      StopForAxis(end_mask, stop_indices, strides, input_dims, 1);
-  const int start_d =
-      StartForAxis(begin_mask, start_indices, strides, input_dims, 0);
-  const int stop_d =
-      StopForAxis(end_mask, stop_indices, strides, input_dims, 0);
+  const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
+                                                  strides, input_dims.sizes, 3);
+  const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+                                                input_dims.sizes, 3);
+  const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
+                                                  strides, input_dims.sizes, 2);
+  const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+                                                input_dims.sizes, 2);
+  const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
+                                                  strides, input_dims.sizes, 1);
+  const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+                                                input_dims.sizes, 1);
+  const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
+                                                  strides, input_dims.sizes, 0);
+  const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+                                                input_dims.sizes, 0);
 
   T* out_ptr = output_data;
-  for (int in_b = start_b; !LoopCondition(in_b, stop_b, strides[3]);
+  for (int in_b = start_b;
+       !strided_slice::LoopCondition(in_b, stop_b, strides[3]);
        in_b += strides[3]) {
-    for (int in_h = start_h; !LoopCondition(in_h, stop_h, strides[2]);
+    for (int in_h = start_h;
+         !strided_slice::LoopCondition(in_h, stop_h, strides[2]);
          in_h += strides[2]) {
-      for (int in_w = start_w; !LoopCondition(in_w, stop_w, strides[1]);
+      for (int in_w = start_w;
+           !strided_slice::LoopCondition(in_w, stop_w, strides[1]);
            in_w += strides[1]) {
-        for (int in_d = start_d; !LoopCondition(in_d, stop_d, strides[0]);
+        for (int in_d = start_d;
+             !strided_slice::LoopCondition(in_d, stop_d, strides[0]);
              in_d += strides[0]) {
           *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
         }
diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
new file mode 100644 (file)
index 0000000..ef77371
--- /dev/null
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
+
+#include <limits>
+#include <vector>
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+
+namespace tflite {
+
+namespace strided_slice {
+
+// Use until std::clamp() is available from C++17.
+inline int Clamp(const int v, const int lo, const int hi) {
+  TFLITE_DCHECK(!(hi < lo));
+  if (hi < v) return hi;
+  if (v < lo) return lo;
+  return v;
+}
+
+// Return the index for the first element along that axis. This index will be a
+// positive integer between [0, axis_size - 1] that can be used to index
+// directly into the data.
+template <typename IntType>
+inline int StartForAxis(int begin_mask,
+                        std::vector<IntType> const& start_indices,
+                        std::vector<IntType> const& strides,
+                        int const* input_shape, int axis) {
+  // Begin with the specified index
+  int start = start_indices[axis];
+
+  // begin_mask override
+  if (begin_mask & 1 << axis) {
+    if (strides[axis] > 0) {
+      // Forward iteration - use the first element. These values will get
+      // clamped below (Note: We could have set them to 0 and axis_size-1, but
+      // use lowest() and max() to maintain symmetry with StopForAxis())
+      start = std::numeric_limits<int>::lowest();
+    } else {
+      // Backward iteration - use the last element.
+      start = std::numeric_limits<int>::max();
+    }
+  }
+
+  // Handle negative indices
+  int axis_size = input_shape[axis];
+  if (start < 0) {
+    start += axis_size;
+  }
+
+  // Clamping
+  start = Clamp(start, 0, axis_size - 1);
+
+  return start;
+}
+
+// Return the "real" index for the end of iteration along that axis. This is an
+// "end" in the traditional C sense, in that it points to one past the last
+// element. ie. So if you were iterating through all elements of a 1D array of
+// size 4, this function would return 4 as the stop, because it is one past the
+// "real" indices of 0, 1, 2 & 3.
+template <typename IntType>
+inline int StopForAxis(int end_mask, std::vector<IntType> const& stop_indices,
+                       std::vector<IntType> const& strides,
+                       int const* input_shape, int axis) {
+  // Begin with the specified index
+  int stop = stop_indices[axis];
+
+  // end_mask override
+  if (end_mask & (1 << axis)) {
+    if (strides[axis] > 0) {
+      // Forward iteration - use the last element. These values will get
+      // clamped below
+      stop = std::numeric_limits<int>::max();
+    } else {
+      // Backward iteration - use the first element.
+      stop = std::numeric_limits<int>::lowest();
+    }
+  }
+
+  // Handle negative indices
+  int axis_size = input_shape[axis];
+  if (stop < 0) {
+    stop += axis_size;
+  }
+
+  // Clamping
+  // Because the end index points one past the last element, we need slightly
+  // different clamping ranges depending on the direction.
+  if (strides[axis] > 0) {
+    // Forward iteration
+    stop = Clamp(stop, 0, axis_size);
+  } else {
+    // Backward iteration
+    stop = Clamp(stop, -1, axis_size - 1);
+  }
+
+  return stop;
+}
+
+inline bool LoopCondition(int index, int stop, int stride) {
+  // True when we have reached the end of an axis and should loop.
+  return stride > 0 ? index >= stop : index <= stop;
+}
+
+}  // namespace strided_slice
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
index f72a4e0..9c9acf6 100644 (file)
@@ -1772,6 +1772,19 @@ def make_strided_slice_tests(zip_path):
           "shrink_axis_mask": [None, 1, 2, 3, -1],
           "constant_indices": [False, True],
       },
+      # 1-D Exhaustive
+      {
+          "dtype": [tf.float32],
+          "index_type": [tf.int32],
+          "input_shape": [[4]],
+          "begin": [[-100], [-3], [-2], [-1], [0], [1], [2], [3], [100]],
+          "end": [[-100], [-3], [-2], [-1], [0], [1], [2], [3], [100]],
+          "strides": [-2, -1, 1, 2],
+          "begin_mask": [0, 1],
+          "end_mask": [0, 1],
+          "shrink_axis_mask": [0],
+          "constant_indices": [False],
+      },
       # Negative strides
       {
           "dtype": [tf.float32],
index 3f73ef6..f92e546 100644 (file)
@@ -308,6 +308,7 @@ cc_library(
         ":toco_port",
         ":tooling_util",
         "//tensorflow/contrib/lite/kernels/internal:quantization_util",
+        "//tensorflow/contrib/lite/kernels/internal:strided_slice_logic",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
index be6e0e0..19037bc 100644 (file)
@@ -20,6 +20,7 @@ limitations under the License.
 #include <vector>
 
 #include "absl/strings/str_join.h"
+#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/contrib/lite/toco/model.h"
 #include "tensorflow/contrib/lite/toco/tooling_util.h"
@@ -1235,83 +1236,6 @@ void ProcessStackOperator(Model* model, StackOperator* op) {
   output_array.copy_shape(*stacked_shape);
 }
 
-// These StridedSlice utility functions are essentially a COPY of those in
-// reference_ops.h. See comments there.
-
-// Use until std::clamp() is available from C++17.
-int Clamp(const int v, const int lo, const int hi) {
-  if (hi < v) return hi;
-  if (v < lo) return lo;
-  return v;
-}
-
-int StartForAxis(StridedSliceOperator const& op, Shape const& input_shape,
-                 int axis) {
-  // Begin with the specified index
-  int start = op.start_indices[axis];
-
-  // begin_mask override
-  if (op.begin_mask & 1 << axis) {
-    if (op.strides[axis] > 0) {
-      // Forward iteration - use the first element. These values will get
-      // clamped below (Note: We could have set them to 0 and axis_size-1, but
-      // use lowest() and max() to maintain symmetry with StopForAxis())
-      start = std::numeric_limits<int>::lowest();
-    } else {
-      // Backward iteration - use the last element.
-      start = std::numeric_limits<int>::max();
-    }
-  }
-
-  // Handle negative indices
-  int axis_size = input_shape.dims(axis);
-  if (start < 0) {
-    start += axis_size;
-  }
-
-  // Clamping
-  start = Clamp(start, 0, axis_size - 1);
-
-  return start;
-}
-
-int StopForAxis(StridedSliceOperator const& op, Shape const& input_shape,
-                int axis) {
-  // Begin with the specified index
-  int stop = op.stop_indices[axis];
-
-  // end_mask override
-  if (op.end_mask & (1 << axis)) {
-    if (op.strides[axis] > 0) {
-      // Forward iteration - use the last element. These values will get
-      // clamped below
-      stop = std::numeric_limits<int>::max();
-    } else {
-      // Backward iteration - use the first element.
-      stop = std::numeric_limits<int>::lowest();
-    }
-  }
-
-  // Handle negative indices
-  int axis_size = input_shape.dims(axis);
-  if (stop < 0) {
-    stop += axis_size;
-  }
-
-  // Clamping
-  // Because the end index points one past the last element, we need slightly
-  // different clamping ranges depending on the direction.
-  if (op.strides[axis] > 0) {
-    // Forward iteration
-    stop = Clamp(stop, 0, axis_size);
-  } else {
-    // Backward iteration
-    stop = Clamp(stop, -1, axis_size - 1);
-  }
-
-  return stop;
-}
-
 void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
   CHECK_GE(op->inputs.size(), 1);
   CHECK_EQ(op->outputs.size(), 1);
@@ -1364,18 +1288,17 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
                                 << " has stride=" << op->strides[i] << ".";
   }
 
-  // The TensorFlow documentation is not explicit on how it handles fewer
-  // supplied indices than dimensions, but they are accepted. We emulate TF's
-  // behavior by fully iterating over each "forgotten" dimension.
-  op->PadIndices(num_input_axes);
-
   // Create output shape
   std::vector<int>* dims = output_array.mutable_shape()->mutable_dims();
 
   // Compute output shape
   for (int axis = 0; axis < num_input_axes; ++axis) {
-    int start_index = StartForAxis(*op, input_array.shape(), axis);
-    int stop_index = StopForAxis(*op, input_array.shape(), axis);
+    int start_index = tflite::strided_slice::StartForAxis(
+        op->begin_mask, op->start_indices, op->strides,
+        input_array.shape().dims().data(), axis);
+    int stop_index = tflite::strided_slice::StopForAxis(
+        op->end_mask, op->stop_indices, op->strides,
+        input_array.shape().dims().data(), axis);
     int dim_size =
         ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]);
 
index 8df3c2f..1dd52e9 100644 (file)
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include <vector>
 
+#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/contrib/lite/toco/model.h"
 #include "tensorflow/contrib/lite/toco/tooling_util.h"
@@ -23,88 +24,6 @@ namespace toco {
 
 namespace {
 
-// These StridedSlice utility functions are essentially a COPY of those in
-// reference_ops.h. See comments there.
-
-// Use until std::clamp() is available from C++17.
-int Clamp(const int v, const int lo, const int hi) {
-  if (hi < v) return hi;
-  if (v < lo) return lo;
-  return v;
-}
-
-int StartForAxis(StridedSliceOperator const& op, Shape const& input_shape,
-                 int axis) {
-  // Begin with the specified index
-  int start = op.start_indices[axis];
-
-  // begin_mask override
-  if (op.begin_mask & 1 << axis) {
-    if (op.strides[axis] > 0) {
-      // Forward iteration - use the first element. These values will get
-      // clamped below (Note: We could have set them to 0 and axis_size-1, but
-      // use lowest() and max() to maintain symmetry with StopForAxis())
-      start = std::numeric_limits<int>::lowest();
-    } else {
-      // Backward iteration - use the last element.
-      start = std::numeric_limits<int>::max();
-    }
-  }
-
-  // Handle negative indices
-  int axis_size = input_shape.dims(axis);
-  if (start < 0) {
-    start += axis_size;
-  }
-
-  // Clamping
-  start = Clamp(start, 0, axis_size - 1);
-
-  return start;
-}
-
-int StopForAxis(StridedSliceOperator const& op, Shape const& input_shape,
-                int axis) {
-  // Begin with the specified index
-  int stop = op.stop_indices[axis];
-
-  // end_mask override
-  if (op.end_mask & (1 << axis)) {
-    if (op.strides[axis] > 0) {
-      // Forward iteration - use the last element. These values will get
-      // clamped below
-      stop = std::numeric_limits<int>::max();
-    } else {
-      // Backward iteration - use the first element.
-      stop = std::numeric_limits<int>::lowest();
-    }
-  }
-
-  // Handle negative indices
-  int axis_size = input_shape.dims(axis);
-  if (stop < 0) {
-    stop += axis_size;
-  }
-
-  // Clamping
-  // Because the end index points one past the last element, we need slightly
-  // different clamping ranges depending on the direction.
-  if (op.strides[axis] > 0) {
-    // Forward iteration
-    stop = Clamp(stop, 0, axis_size);
-  } else {
-    // Backward iteration
-    stop = Clamp(stop, -1, axis_size - 1);
-  }
-
-  return stop;
-}
-
-bool LoopCondition(int index, int stop, int stride) {
-  // True when we have reached the end of an axis and should loop.
-  return stride > 0 ? index >= stop : index <= stop;
-}
-
 template <ArrayDataType Type>
 void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
                   Array* output_array) {
@@ -132,7 +51,9 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
   Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>();
   std::vector<int> src_coord(op.start_indices.size());
   for (int axis = 0; axis < num_input_axes; axis++) {
-    src_coord[axis] = StartForAxis(op, input_shape, axis);
+    src_coord[axis] = tflite::strided_slice::StartForAxis(
+        op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(),
+        axis);
   }
 
   // In order to handle any number (N) of dimensions, we copy elements one by
@@ -155,10 +76,14 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
       }
 
       // Check if we've overflowed.
-      int stop = StopForAxis(op, input_shape, axis);
-      if (LoopCondition(src_coord[axis], stop, stride)) {
+      int stop = tflite::strided_slice::StopForAxis(
+          op.end_mask, op.stop_indices, op.strides, input_shape.dims().data(),
+          axis);
+      if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) {
         // Reset axis and set carry
-        src_coord[axis] = StartForAxis(op, input_shape, axis);
+        src_coord[axis] = tflite::strided_slice::StartForAxis(
+            op.begin_mask, op.start_indices, op.strides,
+            input_shape.dims().data(), axis);
         carry = true;
       } else {
         carry = false;
index 7e8b249..021e991 100644 (file)
@@ -31,6 +31,12 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
   }
 
   CHECK_EQ(op->inputs.size(), 4);
+  const auto& input_array = model->GetArray(op->inputs[0]);
+  if (!input_array.has_shape()) {
+    // We require the dimensionality of the input to pad the indices
+    return false;
+  }
+
   const auto& start_array = model->GetArray(op->inputs[1]);
   if (!start_array.has_shape()) return false;
   if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) {
@@ -57,6 +63,21 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
   CHECK_EQ(op->stop_indices.size(), op->start_indices.size());
   CHECK_EQ(op->strides.size(), op->stop_indices.size());
 
+  // The TensorFlow documentation is not explicit on how it handles fewer
+  // supplied indices than dimensions, but they are accepted. We emulate TF's
+  // behavior by fully iterating over each omitted dimension.
+  int num_input_axes = input_array.shape().dimensions_count();
+  CHECK_LE(op->start_indices.size(), num_input_axes)
+      << "StridedSlice op requires no more than " << num_input_axes
+      << " start indices";
+  CHECK_LE(op->stop_indices.size(), num_input_axes)
+      << "StridedSlice op requires no more than " << num_input_axes
+      << " stop indices";
+  CHECK_LE(op->strides.size(), num_input_axes)
+      << "StridedSlice op requires no more than " << num_input_axes
+      << " strides";
+  op->PadIndices(num_input_axes);
+
   // Ideally, we would remove the input arrays after they have been resolved.
   // However, we must then reconstitute these input arrays for all supported
   // export formats. For now, leave the arrays so we don't have to modify our