2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
// Opens the Boost.Test suite named TensorflowLiteParser; the CONV_2D fixtures
// and test cases that follow belong to it.
11 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
// Fixture whose JSON model contains a single CONV_2D operator with
// fused_activation_function "NONE". The model declares three tensors:
//   - "inputTensor"  with shape [ 1, 3, 3, 1 ]
//   - "outputTensor" with shape [ 1, 1, 1, 1 ]
//   - "filterTensor" with shape [ 1, 3, 3, 1 ]
// One buffer holds the nine values [ 2,1,0, 6,2,1, 4,1,2 ]; presumably this is
// the filter data - confirm against the tensor-to-buffer indices in the JSON.
// The constructor finishes by passing "inputTensor" and "outputTensor" to
// SetupSingleInputSingleOutput.
13 struct SimpleConv2DFixture : public ParserFlatbuffersFixture
15 explicit SimpleConv2DFixture()
20 "operator_codes": [ { "builtin_code": "CONV_2D" } ],
24 "shape": [ 1, 3, 3, 1 ],
27 "name": "inputTensor",
36 "shape": [ 1, 1, 1, 1 ],
39 "name": "outputTensor",
48 "shape": [ 1, 3, 3, 1 ],
51 "name": "filterTensor",
67 "builtin_options_type": "Conv2DOptions",
72 "fused_activation_function": "NONE"
74 "custom_options_format": "FLEXBUFFERS"
81 { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
86 SetupSingleInputSingleOutput("inputTensor", "outputTensor");
// Runs the SimpleConv2DFixture model through RunTest instantiated with
// <4, armnn::DataType::QuantisedAsymm8> (the 4 is presumably the tensor rank -
// TODO confirm against ParserFlatbuffersFixture::RunTest).
90 BOOST_FIXTURE_TEST_CASE( ParseSimpleConv2D, SimpleConv2DFixture )
92 RunTest<4, armnn::DataType::QuantisedAsymm8>(
99 // Because of the output scaling, the expected values must be halved.
// Parameterised fixture for a single CONV_2D operator with a bias input.
// Every constructor argument is a text fragment that is concatenated verbatim
// into the JSON model, so callers must pass valid JSON snippets:
//   inputShape      - "shape" of "inputTensor", e.g. "[ 1, 2, 2, 1 ]"
//   outputShape     - "shape" of "outputTensor"
//   filterShape     - "shape" of "filterTensor"
//   filterData      - "data" array of the first of the two parameterised buffers
//   biasShape       - "shape" of "biasTensor"
//   biasData        - "data" array of the second parameterised buffer
//   strides         - used for BOTH "stride_w" and "stride_h"
//   activation      - value of "fused_activation_function"; it is spliced in
//                     WITHOUT surrounding quotes (default NONE)
//   filterScale     - "scale" entry following "filterTensor" (default 1.0)
//   filterZeroPoint - "zero_point" entry following "filterTensor" (default 0)
//   outputScale     - "scale" entry following "outputTensor" (default 2.0)
//   outputZeroPoint - "zero_point" entry following "outputTensor" (default 0)
// The operator lists inputs [ 0, 2, 3 ]; presumably these indices are the
// input, filter and bias tensors - TODO confirm against the tensor order.
// The constructor finishes by passing "inputTensor" and "outputTensor" to
// SetupSingleInputSingleOutput.
107 struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
109 explicit Conv2DWithBiasesFixture(const std::string & inputShape,
110 const std::string & outputShape,
111 const std::string & filterShape,
112 const std::string & filterData,
113 const std::string & biasShape,
114 const std::string & biasData,
115 const std::string & strides,
116 const std::string & activation="NONE",
117 const std::string & filterScale="1.0",
118 const std::string & filterZeroPoint="0",
119 const std::string & outputScale="2.0",
120 const std::string & outputZeroPoint="0")
125 "operator_codes": [ { "builtin_code": "CONV_2D" } ],
129 "shape": )" + inputShape + R"(,
132 "name": "inputTensor",
141 "shape": )" + outputShape + R"(,
144 "name": "outputTensor",
148 "scale": [ )" + outputScale + R"( ],
149 "zero_point": [ )" + outputZeroPoint + R"( ],
153 "shape": )" + filterShape + R"( ,
156 "name": "filterTensor",
160 "scale": [ )" + filterScale + R"( ],
161 "zero_point": [ )" + filterZeroPoint + R"( ],
165 "shape": )" + biasShape + R"( ,
168 "name": "biasTensor",
182 "inputs": [ 0, 2, 3 ],
184 "builtin_options_type": "Conv2DOptions",
187 "stride_w": )" + strides + R"(,
188 "stride_h": )" + strides + R"(,
189 "fused_activation_function": )" + activation + R"(
191 "custom_options_format": "FLEXBUFFERS"
198 { "data": )" + filterData + R"(, },
199 { "data": )" + biasData + R"(, },
203 SetupSingleInputSingleOutput("inputTensor", "outputTensor");
// Concrete Conv2DWithBiasesFixture: input, output and filter all have shape
// [ 1, 2, 2, 1 ], the filter data is [ 2,1, 0,6 ] and the stride is 1. The
// remaining arguments keep the base-class defaults (activation NONE, filter
// scale 1.0 / zero point 0, output scale 2.0 / zero point 0).
// The bias tensor has shape [ 1 ] but its data lists four values; presumably
// these are the raw bytes of one 32-bit little-endian bias equal to 10 -
// TODO confirm against the bias tensor type in the JSON.
207 struct SimpleConv2DWithBiasesFixture : Conv2DWithBiasesFixture
209 SimpleConv2DWithBiasesFixture()
210 : Conv2DWithBiasesFixture("[ 1, 2, 2, 1 ]", // inputShape
211 "[ 1, 2, 2, 1 ]", // outputShape
212 "[ 1, 2, 2, 1 ]", // filterShape
213 "[ 2,1, 0,6 ]", // filterData
214 "[ 1 ]", // biasShape
215 "[ 10, 0, 0, 0 ]", // biasData
216 "1") // stride w and h
// Checks the output of the SimpleConv2DWithBiasesFixture model. Each expected
// value is written out as the sum of four products with the filter values
// 2, 1, 0, 6, plus the bias of 10.
220 BOOST_FIXTURE_TEST_CASE( ParseConv2DWithBias, SimpleConv2DWithBiasesFixture )
222 RunTest<4, armnn::DataType::QuantisedAsymm8>(
228 // The fixture's output scale is 2.0, so every accumulated value (bias included) is halved.
230 (1*2 + 2*1 + 3*0 + 4*6 + 10)/2,
231 (2*2 + 0*1 + 4*0 + 0*6 + 10)/2,
232 (3*2 + 4*1 + 0*0 + 0*6 + 10)/2,
233 (4*2 + 0*1 + 0*0 + 0*6 + 10)/2
// Conv2DWithBiasesFixture configured with larger shapes: input
// [ 1, 224, 224, 3 ], output [ 1, 112, 112, 32 ], filter [ 32, 3, 3, 3 ],
// bias [ 32 ] and stride 2.
// GenerateInts(n) streams the values i % 256 for i in [0, n) into a string,
// which is then spliced into the model as buffer data. The filter buffer gets
// 32*3*3*3 values (one per filter element); the bias buffer gets 32*4 values,
// presumably four bytes for each of the 32 bias elements - TODO confirm.
237 struct Conv2DShapeTestFixture : Conv2DWithBiasesFixture
239 static std::string GenerateInts(unsigned int n)
241 std::stringstream ss;
243 for( unsigned int i=0; i<n; ++i ) {
248 ss << " " << (i%256); // values wrap at 256, so each one stays in 0..255
254 Conv2DShapeTestFixture()
255 : Conv2DWithBiasesFixture("[ 1, 224, 224, 3 ]", // inputShape
256 "[ 1, 112, 112, 32 ]", // outputShape
257 "[ 32, 3, 3, 3 ]", // filterShape
258 GenerateInts(32*3*3*3), // filterData
259 "[ 32 ]", // biasShape
260 GenerateInts(32*4), // biasData
261 "2") // stride w and h
// Test case built on Conv2DShapeTestFixture; judging by its name it targets the
// 112x112 output shape declared by that fixture - TODO confirm in the body.
265 BOOST_FIXTURE_TEST_CASE( ParseConv2D_112x112_out, Conv2DShapeTestFixture )
// Conv2DWithBiasesFixture with a fused RELU activation and non-zero zero
// points: filter zero point 4 and output zero point 20, with filter scale 1.0
// and output scale 2.0. Shapes are all [ 1, 2, 2, 1 ], the filter data is
// [ 2,1, 0,6 ] and the stride is 1.
// The bias data lists four values for a bias of shape [ 1 ]; presumably the raw
// bytes of one 32-bit little-endian bias equal to 16 - TODO confirm.
269 struct ReluConv2DWithBiasesFixture : Conv2DWithBiasesFixture
271 ReluConv2DWithBiasesFixture()
272 : Conv2DWithBiasesFixture("[ 1, 2, 2, 1 ]", // inputShape
273 "[ 1, 2, 2, 1 ]", // outputShape
274 "[ 1, 2, 2, 1 ]", // filterShape
275 "[ 2,1, 0,6 ]", // filterData
276 "[ 1 ]", // biasShape
277 "[ 16, 0, 0, 0 ]", // biasData
278 "1", // stride w and h
279 "RELU", // activation
280 "1.0", // filter scale
281 "4", // filter zero point
282 "2.0", // output scale
283 "20") // output zero point
// Checks the output of the ReluConv2DWithBiasesFixture model (RELU activation,
// filter zero point 4, output zero point 20).
287 BOOST_FIXTURE_TEST_CASE( ParseConv2DAndReluWithBias, ReluConv2DWithBiasesFixture )
290 uint8_t outZero = 20; // output zero point, matching the fixture
291 uint8_t fz = 4; // filter zero point
293 RunTest<4, armnn::DataType::QuantisedAsymm8>(
299 // Factors to consider in the expected values:
300 // - the filter zero point is non-zero, hence the (x-fz)
301 // - the output scale is 2, hence the /2
302 // - the output zero point is non-zero, hence the +outZero
303 // - RELU cuts negative values, and then we add the output zero point
// - fz and outZero are uint8_t, but integer promotion makes the arithmetic
//   signed int, so (x-fz) may be negative; the cast back to uint8_t only
//   happens after +outZero has been added
// - NOTE(review): std::max is applied after the uint8_t cast, so the clamp is
//   only meaningful while the value before the cast is non-negative
305 std::max(outZero, static_cast<uint8_t>((1*(2-fz) + 2*(1-fz) + 4*(0-fz) + 8*(6-fz) + bias)/2 + outZero)),
306 std::max(outZero, static_cast<uint8_t>((2*(2-fz) + 0*(1-fz) + 8*(0-fz) + 0*(6-fz) + bias)/2 + outZero)),
307 std::max(outZero, static_cast<uint8_t>((4*(2-fz) + 8*(1-fz) + 0*(0-fz) + 0*(6-fz) + bias)/2 + outZero)),
308 std::max(outZero, static_cast<uint8_t>((8*(2-fz) + 0*(1-fz) + 0*(0-fz) + 0*(6-fz) + bias)/2 + outZero))
// Conv2DWithBiasesFixture with a fused RELU6 activation. Shapes are all
// [ 1, 2, 2, 1 ], the filter data is [ 2,1, 0,6 ], the bias data is all zeros
// and the stride is 1. The quantisation arguments are spelled out explicitly
// but equal the base-class defaults (filter scale 1.0 / zero point 0, output
// scale 2.0 / zero point 0).
312 struct Relu6Conv2DWithBiasesFixture : Conv2DWithBiasesFixture
314 Relu6Conv2DWithBiasesFixture()
315 : Conv2DWithBiasesFixture("[ 1, 2, 2, 1 ]", // inputShape
316 "[ 1, 2, 2, 1 ]", // outputShape
317 "[ 1, 2, 2, 1 ]", // filterShape
318 "[ 2,1, 0,6 ]", // filterData
319 "[ 1 ]", // biasShape
320 "[ 0, 0, 0, 0 ]", // biasData
321 "1", // stride w and h
322 "RELU6", // activation
323 "1.0", // filter scale
324 "0", // filter zero point
325 "2.0", // output scale
326 "0") // output zero point
// Checks the output of the Relu6Conv2DWithBiasesFixture model (RELU6
// activation, zero bias, output scale 2.0).
// NOTE(review): relu6Min is misleadingly named - it is the UPPER clamp value
// and is applied with std::min; consider renaming it (e.g. relu6Max).
330 BOOST_FIXTURE_TEST_CASE( ParseConv2DAndRelu6WithBias, Relu6Conv2DWithBiasesFixture )
332 uint8_t relu6Min = 6 / 2; // RELU6 bound of 6 divided by the output scale of 2
334 RunTest<4, armnn::DataType::QuantisedAsymm8>(
340 // Factors to consider in the expected values:
341 // - the output scale is 2, hence the /2
342 // - RELU6 cuts output values at +6, i.e. at relu6Min once scaled
344 std::min(relu6Min, static_cast<uint8_t>((1*2 + 2*1 + 4*0 + 1*6)/2)),
345 std::min(relu6Min, static_cast<uint8_t>((2*2 + 0*1 + 1*0 + 0*6)/2)),
346 std::min(relu6Min, static_cast<uint8_t>((4*2 + 1*1 + 0*0 + 0*6)/2)),
347 std::min(relu6Min, static_cast<uint8_t>((1*2 + 0*1 + 0*0 + 0*6)/2))
// Closes the TensorflowLiteParser test suite.
351 BOOST_AUTO_TEST_SUITE_END()