return masked ? (stride > 0 ? 0 : dim - 1) : start;
}
-inline int StopIndex(int stop, int stride, int dim, bool masked) {
- return masked ? (stride > 0 ? dim : -1) : stop;
+inline int StopIndex(int start, int stop, int stride, int dim, bool masked,
+ bool shrink_axis_masked) {
+ return shrink_axis_masked ? stride > 0 ? start + 1 : start - 1
+ : masked ? (stride > 0 ? dim : -1) : stop;
}
template <typename T>
inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
- int begin_mask, int end_mask,
+ int begin_mask, int end_mask, int shrink_axis_mask,
const std::vector<int>& starts,
const std::vector<int>& stops,
const std::vector<int>& strides, T* output_data,
const int start_b =
StartIndex(starts[3], strides[3], input_dims.sizes[3], begin_mask & 8);
const int stop_b =
- StopIndex(stops[3], strides[3], input_dims.sizes[3], end_mask & 8);
+ StopIndex(start_b, stops[3], strides[3], input_dims.sizes[3],
+ end_mask & 8, shrink_axis_mask & 8);
const int start_h =
StartIndex(starts[2], strides[2], input_dims.sizes[2], begin_mask & 4);
const int stop_h =
- StopIndex(stops[2], strides[2], input_dims.sizes[2], end_mask & 4);
+ StopIndex(start_h, stops[2], strides[2], input_dims.sizes[2],
+ end_mask & 4, shrink_axis_mask & 4);
const int start_w =
StartIndex(starts[1], strides[1], input_dims.sizes[1], begin_mask & 2);
const int stop_w =
- StopIndex(stops[1], strides[1], input_dims.sizes[1], end_mask & 2);
+ StopIndex(start_w, stops[1], strides[1], input_dims.sizes[1],
+ end_mask & 2, shrink_axis_mask & 2);
const int start_d =
StartIndex(starts[0], strides[0], input_dims.sizes[0], begin_mask & 1);
const int stop_d =
- StopIndex(stops[0], strides[0], input_dims.sizes[0], end_mask & 1);
+ StopIndex(start_d, stops[0], strides[0], input_dims.sizes[0],
+ end_mask & 1, shrink_axis_mask & 1);
T* out_ptr = output_data;
for (int in_b = start_b; LoopCondition(in_b, stop_b, strides[3]);
}
}
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+ int begin_mask, int end_mask,
+ const std::vector<int>& starts,
+ const std::vector<int>& stops,
+ const std::vector<int>& strides, T* output_data,
+ const Dims<4>& output_dims) {
+ StridedSlice(input_data, input_dims, begin_mask, end_mask,
+ /*shrink_axis_mask=*/0, starts, stops, strides, output_data,
+ output_dims);
+}
+
template <typename T>
inline void Slice(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& begin, const std::vector<int>& size,
"ellipsis_mask is not implemented yet.");
TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0,
"new_axis_mask is not implemented yet.");
- TF_LITE_ENSURE_MSG(context, op_context.params->shrink_axis_mask == 0,
- "shrink_axis_mask is not implemented yet.");
// TODO(soroosh): optimize for constant tensors to do allocation in Prepare
op_context.output->allocation_type = kTfLiteDynamic;
std::vector<int> starts;
std::vector<int> stops;
std::vector<int> strides;
+ std::vector<int> output_shape_vector;
- // Determine size of output tensor and map indices
- TfLiteIntArray* output_shape = TfLiteIntArrayCreate(op_context.dims);
for (int idx = op_context.dims - 1; idx >= 0; --idx) {
int dim = op_context.input->dims->data[idx];
int32_t stride = GetTensorData<int32_t>(op_context.strides)[idx];
pos_stride);
// This is valid for both positive and negative strides
- output_shape->data[idx] = ceil((end - begin) / static_cast<float>(stride));
- output_shape->data[idx] =
- output_shape->data[idx] < 0 ? 0 : output_shape->data[idx];
+ int32_t dim_shape = ceil((end - begin) / static_cast<float>(stride));
+ dim_shape = dim_shape < 0 ? 0 : dim_shape;
+
+ if (!(op_context.params->shrink_axis_mask & (1 << idx))) {
+ output_shape_vector.push_back(dim_shape);
+ }
+
starts.emplace_back(begin);
stops.emplace_back(end);
strides.emplace_back(stride);
}
+ TfLiteIntArray* output_shape =
+ TfLiteIntArrayCreate(output_shape_vector.size());
+
+ std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(),
+ output_shape->data);
+
for (int i = op_context.dims; i < kMaxDim; i++) {
starts.emplace_back(0);
stops.emplace_back(1);
ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
op_context.params->end_mask =
ReverseMaskBits(op_context.params->end_mask, op_context.dims);
-
-#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
- kernel_type::StridedSlice( \
- GetTensorData<data_type>(op_context.input), \
- GetTensorDims(op_context.input), op_context.params->begin_mask, \
- op_context.params->end_mask, starts, stops, strides, \
- GetTensorData<data_type>(op_context.output), \
+ op_context.params->shrink_axis_mask =
+ ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims);
+
+#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
+ kernel_type::StridedSlice( \
+ GetTensorData<data_type>(op_context.input), \
+ GetTensorDims(op_context.input), op_context.params->begin_mask, \
+ op_context.params->end_mask, op_context.params->shrink_axis_mask, \
+ starts, stops, strides, GetTensorData<data_type>(op_context.output), \
GetTensorDims(op_context.output))
switch (op_context.input->type) {
"ellipsis_mask is not implemented yet.");
EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0),
"new_axis_mask is not implemented yet.");
- EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 0, 1),
- "shrink_axis_mask is not implemented yet.");
}
TEST(StridedSliceOpTest, In1D) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
}
+
TEST(StridedSliceOpTest, In1D_NegStride) {
StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3});
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
}
+
TEST(StridedSliceOpTest, In1D_OddLenStride2) {
StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3});
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
+
TEST(StridedSliceOpTest, In2D) {
StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4}));
}
+
TEST(StridedSliceOpTest, In2D_NegStrideEndMask) {
StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}));
}
+
TEST(StridedSliceOpTest, In3D_Strided2) {
StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5}));
}
+TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({1});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_TRUE(m.GetOutputShape().empty());
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
+}
+
+TEST(StridedSliceOpTest, In1D_EmptyOutputShrinkAxisMask1) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({2});
+ m.SetEnd({1});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_TRUE(m.GetOutputShape().empty());
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
+}
+
+TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({1});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_TRUE(m.GetOutputShape().empty());
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
+}
+
+TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({-2});
+ m.SetEnd({-3});
+ m.SetStrides({-1});
+ m.Invoke();
+ EXPECT_TRUE(m.GetOutputShape().empty());
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
+}
+
+TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({0, 0});
+ m.SetEnd({2, 3});
+ m.SetStrides({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
+}
+
+TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({0, 0});
+ m.SetEnd({2, 3});
+ m.SetStrides({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4}));
+}
+
+TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({0, 0});
+ m.SetEnd({2, 3});
+ m.SetStrides({1, 1});
+ m.Invoke();
+ EXPECT_TRUE(m.GetOutputShape().empty());
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) {
+ StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetEnd({2, 3, 2});
+ m.SetStrides({1, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) {
+ StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetEnd({2, 3, 2});
+ m.SetStrides({1, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 7, 8}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) {
+ StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetEnd({2, 3, 2});
+ m.SetStrides({1, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) {
+ StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetEnd({2, 3, 2});
+ m.SetStrides({1, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5, 7, 9, 11}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) {
+ StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetEnd({2, 3, 2});
+ m.SetStrides({1, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) {
+ StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetEnd({2, 3, 2});
+ m.SetStrides({1, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 7}));
+}
+
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) {
+ StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetEnd({2, 3, 2});
+ m.SetStrides({1, 1, 1});
+ m.Invoke();
+ EXPECT_TRUE(m.GetOutputShape().empty());
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
+}
} // namespace
} // namespace tflite