2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include <armnnDeserializer/IDeserializer.hpp>
9 #include <QuantizeHelper.hpp>
10 #include <ResolveType.hpp>
12 #include <boost/test/unit_test.hpp>
// Boost.Test suite grouping the comparison-layer deserializer test cases below;
// closed by the matching BOOST_AUTO_TEST_SUITE_END().
// NOTE(review): this listing carries original line numbers and omits a number of
// lines (braces, some assignments, one RunTest argument); the comments in it
// describe only what is visible.
16 BOOST_AUTO_TEST_SUITE(Deserializer)
// Declares a fixture struct named Simple<operation><dataType>Fixture that derives
// from SimpleComparisonFixture. Its default constructor forwards the stringized
// macro arguments (#dataType, #operation) to SimpleComparisonFixture, whose
// constructor takes them as (inputDataType, comparisonOperation).
// The macro is only expanded at the DECLARE_SIMPLE_COMPARISON_TEST_CASE uses further
// down, after SimpleComparisonFixture has been defined.
18 #define DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType) \
19 struct Simple##operation##dataType##Fixture : public SimpleComparisonFixture \
21 Simple##operation##dataType##Fixture() \
22 : SimpleComparisonFixture(#dataType, #operation) {} \
// Declares the Simple<operation><dataType>Fixture struct plus a Boost fixture test
// case named <operation><dataType> (e.g. EqualFloat32). The test body:
//  - sets T = armnn::ResolveType<armnn::DataType::<dataType>>, the element type
//    used for armnnUtils::QuantizedVector<T>;
//  - passes both float input vectors from s_TestData through QuantizedVector<T>
//    with fixed quantization parameters (scale 1.f, offset 0);
//  - calls RunTest<4, <dataType>, Boolean>, feeding the "InputLayer0" and
//    "InputLayer1" bindings and checking "OutputLayer" against
//    s_TestData.m_Output<operation>.
// The binding names must match the layerName fields in ComparisonFixture's JSON.
// s_TestData must have a member named m_Output<operation> because the operation
// token is pasted onto "m_Output".
// NOTE(review): 4 is presumably the output rank (all shapes are [ 2, 2, 2, 2 ]).
// One argument line of the RunTest call is not visible in this listing.
25 #define DECLARE_SIMPLE_COMPARISON_TEST_CASE(operation, dataType) \
26 DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType) \
27 BOOST_FIXTURE_TEST_CASE(operation##dataType, Simple##operation##dataType##Fixture) \
29 using T = armnn::ResolveType<armnn::DataType::dataType>; \
30 constexpr float qScale = 1.f; \
31 constexpr int32_t qOffset = 0; \
32 RunTest<4, armnn::DataType::dataType, armnn::DataType::Boolean>( \
34 {{ "InputLayer0", armnnUtils::QuantizedVector<T>(s_TestData.m_InputData0, qScale, qOffset) }, \
35 { "InputLayer1", armnnUtils::QuantizedVector<T>(s_TestData.m_InputData1, qScale, qOffset) }}, \
36 {{ "OutputLayer", s_TestData.m_Output##operation }}); \
// Base fixture that describes a serialized network as text. The visible layers
// are InputLayer0, InputLayer1, ComparisonLayer (inputs connected from source
// layers 0 and 1) and OutputLayer (input connected from source layer 2).
// The constructor arguments are concatenated into the text between raw-string
// segments:
//   inputShape0 / inputShape1 - "dimensions" of InputLayer0 / InputLayer1,
//   outputShape               - "dimensions" in the ComparisonLayer and
//                               OutputLayer entries,
//   inputDataType             - "dataType" of both input layers (e.g. "Float32"),
//   comparisonOperation       - the ComparisonLayer "operation" (e.g. "Equal").
// The arguments are inserted without quoting or validation. Callers must
// therefore pass tokens the parser accepts, such as shape lists like
// "[ 2, 2, 2, 2 ]" and enum names.
// NOTE(review): the member that receives the string, the output tensor's dataType
// and any setup call are on lines not visible in this listing.
39 struct ComparisonFixture : public ParserFlatbuffersSerializeFixture
41 explicit ComparisonFixture(const std::string& inputShape0,
42 const std::string& inputShape1,
43 const std::string& outputShape,
44 const std::string& inputDataType,
45 const std::string& comparisonOperation)
53 layer_type: "InputLayer",
59 layerName: "InputLayer0",
63 connection: { sourceLayerIndex:0, outputSlotIndex:0 },
68 dimensions: )" + inputShape0 + R"(,
69 dataType: )" + inputDataType + R"(
77 layer_type: "InputLayer",
83 layerName: "InputLayer1",
87 connection: { sourceLayerIndex:0, outputSlotIndex:0 },
92 dimensions: )" + inputShape1 + R"(,
93 dataType: )" + inputDataType + R"(
101 layer_type: "ComparisonLayer",
105 layerName: "ComparisonLayer",
106 layerType: "Comparison",
109 connection: { sourceLayerIndex:0, outputSlotIndex:0 },
113 connection: { sourceLayerIndex:1, outputSlotIndex:0 },
118 dimensions: )" + outputShape + R"(,
124 operation: )" + comparisonOperation + R"(
129 layer_type: "OutputLayer",
135 layerName: "OutputLayer",
139 connection: { sourceLayerIndex:2, outputSlotIndex:0 },
144 dimensions: )" + outputShape + R"(,
// Shared inputs and expected outputs for the [ 2, 2, 2, 2 ] comparison tests.
// Each vector holds 16 values, arranged as four runs of four.
// The two input vectors use the runs 1,5,3,4 and 1,3,5,4. They are presumably
// assigned to m_InputData0 and m_InputData1 in that order; the assignment lines
// are not visible in this listing.
// With these runs, every operation sees an equal pair (1 vs 1, 4 vs 4), a greater
// pair (5 vs 3) and a less pair (3 vs 5).
// Each expected-output vector is the elementwise result (1 = true, 0 = false) of
// its operation on those runs. For example, GreaterOrEqual gives 1,1,0,1 per run.
161 SimpleComparisonTestData()
165 1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
166 3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f
171 1.f, 1.f, 1.f, 1.f, 3.f, 3.f, 3.f, 3.f,
172 5.f, 5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 4.f
177 1, 1, 1, 1, 0, 0, 0, 0,
178 0, 0, 0, 0, 1, 1, 1, 1
183 0, 0, 0, 0, 1, 1, 1, 1,
184 0, 0, 0, 0, 0, 0, 0, 0
187 m_OutputGreaterOrEqual =
189 1, 1, 1, 1, 1, 1, 1, 1,
190 0, 0, 0, 0, 1, 1, 1, 1
195 0, 0, 0, 0, 0, 0, 0, 0,
196 1, 1, 1, 1, 0, 0, 0, 0
199 m_OutputLessOrEqual =
201 1, 1, 1, 1, 0, 0, 0, 0,
202 1, 1, 1, 1, 1, 1, 1, 1
207 0, 0, 0, 0, 1, 1, 1, 1,
208 1, 1, 1, 1, 0, 0, 0, 0
// Float inputs; each test case passes them through armnnUtils::QuantizedVector<T>
// for its own data type.
212 std::vector<float> m_InputData0;
213 std::vector<float> m_InputData1;
// Expected outputs, one per operation. The names must stay m_Output<Operation>,
// because DECLARE_SIMPLE_COMPARISON_TEST_CASE pastes the operation token onto
// m_Output.
215 std::vector<uint8_t> m_OutputEqual;
216 std::vector<uint8_t> m_OutputGreater;
217 std::vector<uint8_t> m_OutputGreaterOrEqual;
218 std::vector<uint8_t> m_OutputLess;
219 std::vector<uint8_t> m_OutputLessOrEqual;
220 std::vector<uint8_t> m_OutputNotEqual;
// ComparisonFixture specialised to identical [ 2, 2, 2, 2 ] shapes for both inputs
// and the output (16 elements, matching the SimpleComparisonTestData vectors).
// Only the data-type and operation names vary; they are supplied by the
// DECLARE_SIMPLE_COMPARISON_FIXTURE macro.
// NOTE(review): the line forwarding inputDataType to ComparisonFixture is not
// visible in this listing.
223 struct SimpleComparisonFixture : public ComparisonFixture
225 SimpleComparisonFixture(const std::string& inputDataType,
226 const std::string& comparisonOperation)
227 : ComparisonFixture("[ 2, 2, 2, 2 ]", // inputShape0
228 "[ 2, 2, 2, 2 ]", // inputShape1
229 "[ 2, 2, 2, 2 ]", // outputShape,
231 comparisonOperation) {}
// A single copy of the test data, shared by every fixture instance.
233 static SimpleComparisonTestData s_TestData;
// Out-of-class definition of the static test data declared in SimpleComparisonFixture.
236 SimpleComparisonTestData SimpleComparisonFixture::s_TestData;
// Float32: one test case per comparison operation.
238 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal, Float32)
239 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater, Float32)
240 DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, Float32)
241 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less, Float32)
242 DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual, Float32)
243 DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual, Float32)
// QuantisedAsymm8: one test case per comparison operation. The cases are wrapped
// in ARMNN_NO_DEPRECATE_WARN_BEGIN/END, presumably to silence deprecation warnings
// for that enumerator.
// NOTE(review): if QuantisedAsymm8 is a deprecated alias of QAsymmU8, these cases
// duplicate the QAsymmU8 ones below. They could then be dropped together with the
// enumerator; confirm against armnn::DataType.
246 ARMNN_NO_DEPRECATE_WARN_BEGIN
247 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal, QuantisedAsymm8)
248 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater, QuantisedAsymm8)
249 DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, QuantisedAsymm8)
250 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less, QuantisedAsymm8)
251 DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual, QuantisedAsymm8)
252 DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual, QuantisedAsymm8)
253 ARMNN_NO_DEPRECATE_WARN_END
// QAsymmU8: one test case per comparison operation.
255 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal, QAsymmU8)
256 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater, QAsymmU8)
257 DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, QAsymmU8)
258 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less, QAsymmU8)
259 DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual, QAsymmU8)
260 DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual, QAsymmU8)
// Closes the Deserializer test suite opened by BOOST_AUTO_TEST_SUITE(Deserializer).
262 BOOST_AUTO_TEST_SUITE_END()