using ::int32;
using ::testing::ElementsAreArray;
+template <typename input_type = float,
+ TensorType tensor_input_type = TensorType_FLOAT32>
class StridedSliceOpModel : public SingleOpModel {
public:
StridedSliceOpModel(std::initializer_list<int> input_shape,
std::initializer_list<int> strides_shape, int begin_mask,
int end_mask, int ellipsis_mask, int new_axis_mask,
int shrink_axis_mask) {
- input_ = AddInput(TensorType_FLOAT32);
+ input_ = AddInput(tensor_input_type);
begin_ = AddInput(TensorType_INT32);
end_ = AddInput(TensorType_INT32);
strides_ = AddInput(TensorType_INT32);
- output_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(tensor_input_type);
SetBuiltinOp(
BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions,
CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask,
BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape});
}
- void SetInput(std::initializer_list<float> data) {
- PopulateTensor<float>(input_, data);
+ void SetInput(std::initializer_list<input_type> data) {
+ PopulateTensor<input_type>(input_, data);
}
void SetBegin(std::initializer_list<int32> data) {
PopulateTensor<int32>(begin_, data);
PopulateTensor<int32>(strides_, data);
}
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<input_type> GetOutput() {
+ return ExtractVector<input_type>(output_);
+ }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
TEST(StridedSliceOpTest, UnsupportedInputSize) {
EXPECT_DEATH(
- StridedSliceOpModel({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0),
+ StridedSliceOpModel<>({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0),
"StridedSlice op only supports 1D-4D input arrays.");
}
TEST(StridedSliceOpTest, UnssupportedArgs) {
- EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0),
+ EXPECT_DEATH(StridedSliceOpModel<>({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0),
"ellipsis_mask is not implemented yet.");
- EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0),
+ EXPECT_DEATH(StridedSliceOpModel<>({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0),
"new_axis_mask is not implemented yet.");
}
TEST(StridedSliceOpTest, In1D) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({1});
m.SetEnd({3});
}
TEST(StridedSliceOpTest, In1D_EmptyOutput) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({10});
m.SetEnd({3});
}
TEST(StridedSliceOpTest, In1D_NegativeBegin) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-3});
m.SetEnd({3});
}
TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-5});
m.SetEnd({3});
}
TEST(StridedSliceOpTest, In1D_NegativeEnd) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({1});
m.SetEnd({-2});
}
TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-3});
m.SetEnd({5});
}
TEST(StridedSliceOpTest, In1D_BeginMask) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({1});
m.SetEnd({3});
}
TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-2});
m.SetEnd({-3});
}
TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({5});
m.SetEnd({2});
}
TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({2});
m.SetEnd({-4});
}
TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-3});
m.SetEnd({-5});
}
TEST(StridedSliceOpTest, In1D_EndMask) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({1});
m.SetEnd({3});
}
TEST(StridedSliceOpTest, In1D_NegStride) {
- StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3});
m.SetBegin({-1});
m.SetEnd({-4});
}
TEST(StridedSliceOpTest, In1D_EvenLenStride2) {
- StridedSliceOpModel m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2});
m.SetBegin({0});
m.SetEnd({2});
}
TEST(StridedSliceOpTest, In1D_OddLenStride2) {
- StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3});
m.SetBegin({0});
m.SetEnd({3});
}
TEST(StridedSliceOpTest, In2D_Identity) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({0, 0});
m.SetEnd({2, 3});
}
TEST(StridedSliceOpTest, In2D) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({1, 0});
m.SetEnd({2, 2});
}
TEST(StridedSliceOpTest, In2D_Stride2) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({0, 0});
m.SetEnd({2, 3});
}
TEST(StridedSliceOpTest, In2D_NegStride) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({1, -1});
m.SetEnd({2, -4});
}
TEST(StridedSliceOpTest, In2D_BeginMask) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({1, 0});
m.SetEnd({2, 2});
}
TEST(StridedSliceOpTest, In2D_EndMask) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({1, 0});
m.SetEnd({2, 2});
}
TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({1, -2});
m.SetEnd({2, -4});
}
TEST(StridedSliceOpTest, In2D_NegStrideEndMask) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({1, -2});
m.SetEnd({2, -3});
}
TEST(StridedSliceOpTest, In3D_Identity) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
}
TEST(StridedSliceOpTest, In3D_NegStride) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({-1, -1, -1});
m.SetEnd({-3, -4, -3});
}
TEST(StridedSliceOpTest, In3D_Strided2) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
}
TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4});
m.SetBegin({1});
m.SetEnd({3});
}
TEST(StridedSliceOpTest, In1D_EmptyOutputShrinkAxisMask1) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4});
m.SetBegin({2});
m.SetEnd({1});
}
TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4});
m.SetBegin({1});
m.SetEnd({3});
}
TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-2});
m.SetEnd({-3});
}
TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({0, 0});
m.SetEnd({2, 3});
}
TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({0, 0});
m.SetEnd({2, 3});
}
TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({0, 0});
m.SetEnd({2, 3});
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
// This tests catches a very subtle bug that was fixed by cl/188403234.
TEST(StridedSliceOpTest, RunTwice) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
auto setup_inputs = [&m]() {
m.SetInput({1, 2, 3, 4, 5, 6});
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5}));
}
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) {
+ StridedSliceOpModel<uint8, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0,
+ 0, 0, 1);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetEnd({2, 3, 2});
+ m.SetStrides({1, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
} // namespace
} // namespace tflite