copts = tflite_copts(),
deps = [
":quantization_util",
+ ":strided_slice_logic",
":types",
":round",
"//third_party/eigen3",
)
cc_library(
+ name = "strided_slice_logic",
+ srcs = [],
+ hdrs = [
+ "strided_slice_logic.h",
+ ],
+ deps = [
+ ":types",
+ ],
+)
+
+cc_library(
name = "reference_base",
srcs = [],
hdrs = [
deps = [
":quantization_util",
":round",
+ ":strided_slice_logic",
":types",
"//third_party/eigen3",
"@gemmlowp",
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
output_dims, 0);
}
-// UNOPTIMIZED COPY of StridedSlice from reference_ops.h (see comments there).
-
-// Use until std::clamp() is available from C++17.
-inline int Clamp(const int v, const int lo, const int hi) {
- TFLITE_DCHECK(!(hi < lo));
- if (hi < v) return hi;
- if (v < lo) return lo;
- return v;
-}
-
-inline int StartForAxis(int begin_mask, const std::vector<int>& start_indices,
- const std::vector<int>& strides,
- const Dims<4>& input_shape, int axis) {
- // Begin with the specified index
- int start = start_indices[axis];
-
- // begin_mask override
- if (begin_mask & 1 << axis) {
- if (strides[axis] > 0) {
- // Forward iteration - use the first element. These values will get
- // clamped below (Note: We could have set them to 0 and axis_size-1, but
- // use lowest() and max() to maintain symmetry with StopForAxis())
- start = std::numeric_limits<int>::lowest();
- } else {
- // Backward iteration - use the last element.
- start = std::numeric_limits<int>::max();
- }
- }
-
- // Handle negative indices
- int axis_size = input_shape.sizes[axis];
- if (start < 0) {
- start += axis_size;
- }
-
- // Clamping
- start = Clamp(start, 0, axis_size - 1);
-
- return start;
-}
-
-inline int StopForAxis(int end_mask, const std::vector<int>& stop_indices,
- const std::vector<int>& strides,
- const Dims<4>& input_shape, int axis) {
- // Begin with the specified index
- int stop = stop_indices[axis];
-
- // end_mask override
- if (end_mask & (1 << axis)) {
- if (strides[axis] > 0) {
- // Forward iteration - use the last element. These values will get
- // clamped below
- stop = std::numeric_limits<int>::max();
- } else {
- // Backward iteration - use the first element.
- stop = std::numeric_limits<int>::lowest();
- }
- }
-
- // Handle negative indices
- int axis_size = input_shape.sizes[axis];
- if (stop < 0) {
- stop += axis_size;
- }
-
- // Clamping
- // Because the end index points one past the last element, we need slightly
- // different clamping ranges depending on the direction.
- if (strides[axis] > 0) {
- // Forward iteration
- stop = Clamp(stop, 0, axis_size);
- } else {
- // Backward iteration
- stop = Clamp(stop, -1, axis_size - 1);
- }
-
- return stop;
-}
-
-inline bool LoopCondition(int index, int stop, int stride) {
- // True when we have reached the end of an axis and should loop.
- return stride > 0 ? index >= stop : index <= stop;
-}
-
+// UNOPTIMIZED COPY of StridedSlice from reference_ops.h.
template <typename T>
inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
int begin_mask, int end_mask,
TFLITE_DCHECK_EQ(start_indices.size(), 4);
TFLITE_DCHECK_EQ(stop_indices.size(), 4);
TFLITE_DCHECK_EQ(strides.size(), 4);
- const int start_b =
- StartForAxis(begin_mask, start_indices, strides, input_dims, 3);
- const int stop_b =
- StopForAxis(end_mask, stop_indices, strides, input_dims, 3);
- const int start_h =
- StartForAxis(begin_mask, start_indices, strides, input_dims, 2);
- const int stop_h =
- StopForAxis(end_mask, stop_indices, strides, input_dims, 2);
- const int start_w =
- StartForAxis(begin_mask, start_indices, strides, input_dims, 1);
- const int stop_w =
- StopForAxis(end_mask, stop_indices, strides, input_dims, 1);
- const int start_d =
- StartForAxis(begin_mask, start_indices, strides, input_dims, 0);
- const int stop_d =
- StopForAxis(end_mask, stop_indices, strides, input_dims, 0);
+ const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
+ strides, input_dims.sizes, 3);
+ const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+ input_dims.sizes, 3);
+ const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
+ strides, input_dims.sizes, 2);
+ const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+ input_dims.sizes, 2);
+ const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
+ strides, input_dims.sizes, 1);
+ const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+ input_dims.sizes, 1);
+ const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
+ strides, input_dims.sizes, 0);
+ const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+ input_dims.sizes, 0);
T* out_ptr = output_data;
- for (int in_b = start_b; !LoopCondition(in_b, stop_b, strides[3]);
+ for (int in_b = start_b;
+ !strided_slice::LoopCondition(in_b, stop_b, strides[3]);
in_b += strides[3]) {
- for (int in_h = start_h; !LoopCondition(in_h, stop_h, strides[2]);
+ for (int in_h = start_h;
+ !strided_slice::LoopCondition(in_h, stop_h, strides[2]);
in_h += strides[2]) {
- for (int in_w = start_w; !LoopCondition(in_w, stop_w, strides[1]);
+ for (int in_w = start_w;
+ !strided_slice::LoopCondition(in_w, stop_w, strides[1]);
in_w += strides[1]) {
- for (int in_d = start_d; !LoopCondition(in_d, stop_d, strides[0]);
+ for (int in_d = start_d;
+ !strided_slice::LoopCondition(in_d, stop_d, strides[0]);
in_d += strides[0]) {
*out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
}
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
output_dims, 0);
}
-// STRIDED SLICE
-// The functions below for StridedSlice are mirrored in a number of places:
-//
-// propagate_fixed_sizes.cc
-// propagate_shapes.cc
-// resolve_constant_strided_slice.cc
-// optimized_ops.h
-//
-// It is designed for an arbitrary number of dimensions, even though dimensions
-// here are fixed at 4. This is because we expect to eventually support
-// arbitrary dimensionality. Also note that the axis orders are reversed for
-// runtime ops, and so the indices and masks must be as well too.
-//
-// Be warned this code involves some rather subtle logic of python slicing. The
-// best "ground truth" is to compare results to actual python execution.
-
-// Use until std::clamp() is available from C++17.
-inline int Clamp(const int v, const int lo, const int hi) {
- TFLITE_DCHECK(!(hi < lo));
- if (hi < v) return hi;
- if (v < lo) return lo;
- return v;
-}
-
-inline int StartForAxis(int begin_mask, const std::vector<int>& start_indices,
- const std::vector<int>& strides,
- const Dims<4>& input_shape, int axis) {
- // Begin with the specified index
- int start = start_indices[axis];
-
- // begin_mask override
- if (begin_mask & 1 << axis) {
- if (strides[axis] > 0) {
- // Forward iteration - use the first element. These values will get
- // clamped below (Note: We could have set them to 0 and axis_size-1, but
- // use lowest() and max() to maintain symmetry with StopForAxis())
- start = std::numeric_limits<int>::lowest();
- } else {
- // Backward iteration - use the last element.
- start = std::numeric_limits<int>::max();
- }
- }
-
- // Handle negative indices
- int axis_size = input_shape.sizes[axis];
- if (start < 0) {
- start += axis_size;
- }
-
- // Clamping
- start = Clamp(start, 0, axis_size - 1);
-
- return start;
-}
-
-inline int StopForAxis(int end_mask, const std::vector<int>& stop_indices,
- const std::vector<int>& strides,
- const Dims<4>& input_shape, int axis) {
- // Begin with the specified index
- int stop = stop_indices[axis];
-
- // end_mask override
- if (end_mask & (1 << axis)) {
- if (strides[axis] > 0) {
- // Forward iteration - use the last element. These values will get
- // clamped below
- stop = std::numeric_limits<int>::max();
- } else {
- // Backward iteration - use the first element.
- stop = std::numeric_limits<int>::lowest();
- }
- }
-
- // Handle negative indices
- int axis_size = input_shape.sizes[axis];
- if (stop < 0) {
- stop += axis_size;
- }
-
- // Clamping
- // Because the end index points one past the last element, we need slightly
- // different clamping ranges depending on the direction.
- if (strides[axis] > 0) {
- // Forward iteration
- stop = Clamp(stop, 0, axis_size);
- } else {
- // Backward iteration
- stop = Clamp(stop, -1, axis_size - 1);
- }
-
- return stop;
-}
-
-inline bool LoopCondition(int index, int stop, int stride) {
- // True when we have reached the end of an axis and should loop.
- return stride > 0 ? index >= stop : index <= stop;
-}
-
template <typename T>
inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
int begin_mask, int end_mask,
const std::vector<int>& stop_indices,
const std::vector<int>& strides, T* output_data,
const Dims<4>& output_dims) {
+ // Note that the axis orders are reversed for runtime ops, so the indices,
+ // strides and masks must be as well too.
TFLITE_DCHECK_EQ(start_indices.size(), 4);
TFLITE_DCHECK_EQ(stop_indices.size(), 4);
TFLITE_DCHECK_EQ(strides.size(), 4);
- const int start_b =
- StartForAxis(begin_mask, start_indices, strides, input_dims, 3);
- const int stop_b =
- StopForAxis(end_mask, stop_indices, strides, input_dims, 3);
- const int start_h =
- StartForAxis(begin_mask, start_indices, strides, input_dims, 2);
- const int stop_h =
- StopForAxis(end_mask, stop_indices, strides, input_dims, 2);
- const int start_w =
- StartForAxis(begin_mask, start_indices, strides, input_dims, 1);
- const int stop_w =
- StopForAxis(end_mask, stop_indices, strides, input_dims, 1);
- const int start_d =
- StartForAxis(begin_mask, start_indices, strides, input_dims, 0);
- const int stop_d =
- StopForAxis(end_mask, stop_indices, strides, input_dims, 0);
+ const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
+ strides, input_dims.sizes, 3);
+ const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+ input_dims.sizes, 3);
+ const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
+ strides, input_dims.sizes, 2);
+ const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+ input_dims.sizes, 2);
+ const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
+ strides, input_dims.sizes, 1);
+ const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+ input_dims.sizes, 1);
+ const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
+ strides, input_dims.sizes, 0);
+ const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides,
+ input_dims.sizes, 0);
T* out_ptr = output_data;
- for (int in_b = start_b; !LoopCondition(in_b, stop_b, strides[3]);
+ for (int in_b = start_b;
+ !strided_slice::LoopCondition(in_b, stop_b, strides[3]);
in_b += strides[3]) {
- for (int in_h = start_h; !LoopCondition(in_h, stop_h, strides[2]);
+ for (int in_h = start_h;
+ !strided_slice::LoopCondition(in_h, stop_h, strides[2]);
in_h += strides[2]) {
- for (int in_w = start_w; !LoopCondition(in_w, stop_w, strides[1]);
+ for (int in_w = start_w;
+ !strided_slice::LoopCondition(in_w, stop_w, strides[1]);
in_w += strides[1]) {
- for (int in_d = start_d; !LoopCondition(in_d, stop_d, strides[0]);
+ for (int in_d = start_d;
+ !strided_slice::LoopCondition(in_d, stop_d, strides[0]);
in_d += strides[0]) {
*out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
}
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
+
+#include <limits>
+#include <vector>
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+
+namespace tflite {
+
+namespace strided_slice {
+
+// Use until std::clamp() is available from C++17.
+inline int Clamp(const int v, const int lo, const int hi) {
+ TFLITE_DCHECK(!(hi < lo));
+ if (hi < v) return hi;
+ if (v < lo) return lo;
+ return v;
+}
+
+// Return the index for the first element along that axis. This index will be a
+// positive integer between [0, axis_size - 1] that can be used to index
+// directly into the data.
+template <typename IntType>
+inline int StartForAxis(int begin_mask,
+ std::vector<IntType> const& start_indices,
+ std::vector<IntType> const& strides,
+ int const* input_shape, int axis) {
+ // Begin with the specified index
+ int start = start_indices[axis];
+
+ // begin_mask override
+ if (begin_mask & 1 << axis) {
+ if (strides[axis] > 0) {
+ // Forward iteration - use the first element. These values will get
+ // clamped below (Note: We could have set them to 0 and axis_size-1, but
+ // use lowest() and max() to maintain symmetry with StopForAxis())
+ start = std::numeric_limits<int>::lowest();
+ } else {
+ // Backward iteration - use the last element.
+ start = std::numeric_limits<int>::max();
+ }
+ }
+
+ // Handle negative indices
+ int axis_size = input_shape[axis];
+ if (start < 0) {
+ start += axis_size;
+ }
+
+ // Clamping
+ start = Clamp(start, 0, axis_size - 1);
+
+ return start;
+}
+
+// Return the "real" index for the end of iteration along that axis. This is an
+// "end" in the traditional C sense, in that it points to one past the last
+// element. ie. So if you were iterating through all elements of a 1D array of
+// size 4, this function would return 4 as the stop, because it is one past the
+// "real" indices of 0, 1, 2 & 3.
+template <typename IntType>
+inline int StopForAxis(int end_mask, std::vector<IntType> const& stop_indices,
+ std::vector<IntType> const& strides,
+ int const* input_shape, int axis) {
+ // Begin with the specified index
+ int stop = stop_indices[axis];
+
+ // end_mask override
+ if (end_mask & (1 << axis)) {
+ if (strides[axis] > 0) {
+ // Forward iteration - use the last element. These values will get
+ // clamped below
+ stop = std::numeric_limits<int>::max();
+ } else {
+ // Backward iteration - use the first element.
+ stop = std::numeric_limits<int>::lowest();
+ }
+ }
+
+ // Handle negative indices
+ int axis_size = input_shape[axis];
+ if (stop < 0) {
+ stop += axis_size;
+ }
+
+ // Clamping
+ // Because the end index points one past the last element, we need slightly
+ // different clamping ranges depending on the direction.
+ if (strides[axis] > 0) {
+ // Forward iteration
+ stop = Clamp(stop, 0, axis_size);
+ } else {
+ // Backward iteration
+ stop = Clamp(stop, -1, axis_size - 1);
+ }
+
+ return stop;
+}
+
+inline bool LoopCondition(int index, int stop, int stride) {
+ // True when we have reached the end of an axis and should loop.
+ return stride > 0 ? index >= stop : index <= stop;
+}
+
+} // namespace strided_slice
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
"shrink_axis_mask": [None, 1, 2, 3, -1],
"constant_indices": [False, True],
},
+ # 1-D Exhaustive
+ {
+ "dtype": [tf.float32],
+ "index_type": [tf.int32],
+ "input_shape": [[4]],
+ "begin": [[-100], [-3], [-2], [-1], [0], [1], [2], [3], [100]],
+ "end": [[-100], [-3], [-2], [-1], [0], [1], [2], [3], [100]],
+ "strides": [-2, -1, 1, 2],
+ "begin_mask": [0, 1],
+ "end_mask": [0, 1],
+ "shrink_axis_mask": [0],
+ "constant_indices": [False],
+ },
# Negative strides
{
"dtype": [tf.float32],
":toco_port",
":tooling_util",
"//tensorflow/contrib/lite/kernels/internal:quantization_util",
+ "//tensorflow/contrib/lite/kernels/internal:strided_slice_logic",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
#include <vector>
#include "absl/strings/str_join.h"
+#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
output_array.copy_shape(*stacked_shape);
}
-// These StridedSlice utility functions are essentially a COPY of those in
-// reference_ops.h. See comments there.
-
-// Use until std::clamp() is available from C++17.
-int Clamp(const int v, const int lo, const int hi) {
- if (hi < v) return hi;
- if (v < lo) return lo;
- return v;
-}
-
-int StartForAxis(StridedSliceOperator const& op, Shape const& input_shape,
- int axis) {
- // Begin with the specified index
- int start = op.start_indices[axis];
-
- // begin_mask override
- if (op.begin_mask & 1 << axis) {
- if (op.strides[axis] > 0) {
- // Forward iteration - use the first element. These values will get
- // clamped below (Note: We could have set them to 0 and axis_size-1, but
- // use lowest() and max() to maintain symmetry with StopForAxis())
- start = std::numeric_limits<int>::lowest();
- } else {
- // Backward iteration - use the last element.
- start = std::numeric_limits<int>::max();
- }
- }
-
- // Handle negative indices
- int axis_size = input_shape.dims(axis);
- if (start < 0) {
- start += axis_size;
- }
-
- // Clamping
- start = Clamp(start, 0, axis_size - 1);
-
- return start;
-}
-
-int StopForAxis(StridedSliceOperator const& op, Shape const& input_shape,
- int axis) {
- // Begin with the specified index
- int stop = op.stop_indices[axis];
-
- // end_mask override
- if (op.end_mask & (1 << axis)) {
- if (op.strides[axis] > 0) {
- // Forward iteration - use the last element. These values will get
- // clamped below
- stop = std::numeric_limits<int>::max();
- } else {
- // Backward iteration - use the first element.
- stop = std::numeric_limits<int>::lowest();
- }
- }
-
- // Handle negative indices
- int axis_size = input_shape.dims(axis);
- if (stop < 0) {
- stop += axis_size;
- }
-
- // Clamping
- // Because the end index points one past the last element, we need slightly
- // different clamping ranges depending on the direction.
- if (op.strides[axis] > 0) {
- // Forward iteration
- stop = Clamp(stop, 0, axis_size);
- } else {
- // Backward iteration
- stop = Clamp(stop, -1, axis_size - 1);
- }
-
- return stop;
-}
-
void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
<< " has stride=" << op->strides[i] << ".";
}
- // The TensorFlow documentation is not explicit on how it handles fewer
- // supplied indices than dimensions, but they are accepted. We emulate TF's
- // behavior by fully iterating over each "forgotten" dimension.
- op->PadIndices(num_input_axes);
-
// Create output shape
std::vector<int>* dims = output_array.mutable_shape()->mutable_dims();
// Compute output shape
for (int axis = 0; axis < num_input_axes; ++axis) {
- int start_index = StartForAxis(*op, input_array.shape(), axis);
- int stop_index = StopForAxis(*op, input_array.shape(), axis);
+ int start_index = tflite::strided_slice::StartForAxis(
+ op->begin_mask, op->start_indices, op->strides,
+ input_array.shape().dims().data(), axis);
+ int stop_index = tflite::strided_slice::StopForAxis(
+ op->end_mask, op->stop_indices, op->strides,
+ input_array.shape().dims().data(), axis);
int dim_size =
ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]);
==============================================================================*/
#include <vector>
+#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
namespace {
-// These StridedSlice utility functions are essentially a COPY of those in
-// reference_ops.h. See comments there.
-
-// Use until std::clamp() is available from C++17.
-int Clamp(const int v, const int lo, const int hi) {
- if (hi < v) return hi;
- if (v < lo) return lo;
- return v;
-}
-
-int StartForAxis(StridedSliceOperator const& op, Shape const& input_shape,
- int axis) {
- // Begin with the specified index
- int start = op.start_indices[axis];
-
- // begin_mask override
- if (op.begin_mask & 1 << axis) {
- if (op.strides[axis] > 0) {
- // Forward iteration - use the first element. These values will get
- // clamped below (Note: We could have set them to 0 and axis_size-1, but
- // use lowest() and max() to maintain symmetry with StopForAxis())
- start = std::numeric_limits<int>::lowest();
- } else {
- // Backward iteration - use the last element.
- start = std::numeric_limits<int>::max();
- }
- }
-
- // Handle negative indices
- int axis_size = input_shape.dims(axis);
- if (start < 0) {
- start += axis_size;
- }
-
- // Clamping
- start = Clamp(start, 0, axis_size - 1);
-
- return start;
-}
-
-int StopForAxis(StridedSliceOperator const& op, Shape const& input_shape,
- int axis) {
- // Begin with the specified index
- int stop = op.stop_indices[axis];
-
- // end_mask override
- if (op.end_mask & (1 << axis)) {
- if (op.strides[axis] > 0) {
- // Forward iteration - use the last element. These values will get
- // clamped below
- stop = std::numeric_limits<int>::max();
- } else {
- // Backward iteration - use the first element.
- stop = std::numeric_limits<int>::lowest();
- }
- }
-
- // Handle negative indices
- int axis_size = input_shape.dims(axis);
- if (stop < 0) {
- stop += axis_size;
- }
-
- // Clamping
- // Because the end index points one past the last element, we need slightly
- // different clamping ranges depending on the direction.
- if (op.strides[axis] > 0) {
- // Forward iteration
- stop = Clamp(stop, 0, axis_size);
- } else {
- // Backward iteration
- stop = Clamp(stop, -1, axis_size - 1);
- }
-
- return stop;
-}
-
-bool LoopCondition(int index, int stop, int stride) {
- // True when we have reached the end of an axis and should loop.
- return stride > 0 ? index >= stop : index <= stop;
-}
-
template <ArrayDataType Type>
void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
Array* output_array) {
Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>();
std::vector<int> src_coord(op.start_indices.size());
for (int axis = 0; axis < num_input_axes; axis++) {
- src_coord[axis] = StartForAxis(op, input_shape, axis);
+ src_coord[axis] = tflite::strided_slice::StartForAxis(
+ op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(),
+ axis);
}
// In order to handle any number (N) of dimensions, we copy elements one by
}
// Check if we've overflowed.
- int stop = StopForAxis(op, input_shape, axis);
- if (LoopCondition(src_coord[axis], stop, stride)) {
+ int stop = tflite::strided_slice::StopForAxis(
+ op.end_mask, op.stop_indices, op.strides, input_shape.dims().data(),
+ axis);
+ if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) {
// Reset axis and set carry
- src_coord[axis] = StartForAxis(op, input_shape, axis);
+ src_coord[axis] = tflite::strided_slice::StartForAxis(
+ op.begin_mask, op.start_indices, op.strides,
+ input_shape.dims().data(), axis);
carry = true;
} else {
carry = false;
}
CHECK_EQ(op->inputs.size(), 4);
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.has_shape()) {
+ // We require the dimensionality of the input to pad the indices
+ return false;
+ }
+
const auto& start_array = model->GetArray(op->inputs[1]);
if (!start_array.has_shape()) return false;
if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) {
CHECK_EQ(op->stop_indices.size(), op->start_indices.size());
CHECK_EQ(op->strides.size(), op->stop_indices.size());
+ // The TensorFlow documentation is not explicit on how it handles fewer
+ // supplied indices than dimensions, but they are accepted. We emulate TF's
+ // behavior by fully iterating over each omitted dimension.
+ int num_input_axes = input_array.shape().dimensions_count();
+ CHECK_LE(op->start_indices.size(), num_input_axes)
+ << "StridedSlice op requires no more than " << num_input_axes
+ << " start indices";
+ CHECK_LE(op->stop_indices.size(), num_input_axes)
+ << "StridedSlice op requires no more than " << num_input_axes
+ << " stop indices";
+ CHECK_LE(op->strides.size(), num_input_axes)
+ << "StridedSlice op requires no more than " << num_input_axes
+ << " strides";
+ op->PadIndices(num_input_axes);
+
// Ideally, we would remove the input arrays after they have been resolved.
// However, we must then reconstitute these input arrays for all supported
// export formats. For now, leave the arrays so we don't have to modify our