6 #include <boost/test/unit_test.hpp> 8 #include "../TfLiteParser.hpp" 16 explicit StridedSliceFixture(
const std::string & inputShape,
17 const std::string & outputShape,
18 const std::string & beginData,
19 const std::string & endData,
20 const std::string & stridesData,
27 "operator_codes": [ { "builtin_code": "STRIDED_SLICE" } ], 31 "shape": )" + inputShape + R"(, 34 "name": "inputTensor", 46 "name": "beginTensor", 62 "name": "stridesTensor", 67 "shape": )" + outputShape + R"( , 70 "name": "outputTensor", 79 "inputs": [ 0, 1, 2, 3 ], 84 "inputs": [ 0, 1, 2, 3 ], 86 "builtin_options_type": "StridedSliceOptions", 88 "begin_mask": )" + std::to_string(beginMask) + R"(, 89 "end_mask": )" + std::to_string(endMask) + R"( 91 "custom_options_format": "FLEXBUFFERS" 97 { "data": )" + beginData + R"(, }, 98 { "data": )" + endData + R"(, }, 99 { "data": )" + stridesData + R"(, }, 108 struct StridedSlice4DFixture : StridedSliceFixture
110 StridedSlice4DFixture() : StridedSliceFixture(
"[ 3, 2, 3, 1 ]",
112 "[ 1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]",
113 "[ 2,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]",
114 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]" 120 RunTest<4, armnn::DataType::Float32>(
122 {{
"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
124 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
126 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
128 {{
"outputTensor", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}});
131 struct StridedSlice4DReverseFixture : StridedSliceFixture
133 StridedSlice4DReverseFixture() : StridedSliceFixture(
"[ 3, 2, 3, 1 ]",
152 RunTest<4, armnn::DataType::Float32>(
154 {{
"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
156 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
158 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
160 {{
"outputTensor", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}});
163 struct StridedSliceSimpleStrideFixture : StridedSliceFixture
165 StridedSliceSimpleStrideFixture() : StridedSliceFixture(
"[ 3, 2, 3, 1 ]",
167 "[ 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]",
168 "[ 3,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]",
169 "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 1,0,0,0 ]" 175 RunTest<4, armnn::DataType::Float32>(
177 {{
"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
179 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
181 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
183 {{
"outputTensor", { 1.0f, 1.0f,
188 struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
190 StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture(
"[ 3, 2, 3, 1 ]",
192 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]",
193 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]",
194 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]",
202 RunTest<4, armnn::DataType::Float32>(
204 {{
"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
206 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
208 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
210 {{
"outputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
212 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
214 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}});
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
BOOST_AUTO_TEST_SUITE_END()