9881b9e61e8cc850b2122f3ed86b982d7a579ffc
[platform/upstream/armnn.git] / src / armnnDeserializer / test / DeserializeComparison.cpp
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include "../Deserializer.hpp"
8
9 #include <QuantizeHelper.hpp>
10 #include <ResolveType.hpp>
11
12 #include <boost/test/unit_test.hpp>
13
14 #include <string>
15
16 BOOST_AUTO_TEST_SUITE(Deserializer)
17
18 #define DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType) \
19 struct Simple##operation##dataType##Fixture : public SimpleComparisonFixture \
20 { \
21     Simple##operation##dataType##Fixture() \
22         : SimpleComparisonFixture(#dataType, #operation) {} \
23 };
24
25 #define DECLARE_SIMPLE_COMPARISON_TEST_CASE(operation, dataType) \
26 DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType) \
27 BOOST_FIXTURE_TEST_CASE(operation##dataType, Simple##operation##dataType##Fixture) \
28 { \
29     using T = armnn::ResolveType<armnn::DataType::dataType>; \
30     constexpr float   qScale  = 1.f; \
31     constexpr int32_t qOffset = 0; \
32     RunTest<4, armnn::DataType::dataType, armnn::DataType::Boolean>( \
33         0, \
34         {{ "InputLayer0", armnnUtils::QuantizedVector<T>(s_TestData.m_InputData0, qScale, qOffset)  }, \
35          { "InputLayer1", armnnUtils::QuantizedVector<T>(s_TestData.m_InputData1, qScale, qOffset)  }}, \
36         {{ "OutputLayer", s_TestData.m_Output##operation }}); \
37 }
38
// Fixture that describes a four-layer comparison network as JSON text:
//   layer 0: InputLayer      "InputLayer0"     (layerBindingId 0)
//   layer 1: InputLayer      "InputLayer1"     (layerBindingId 1)
//   layer 2: ComparisonLayer "ComparisonLayer" (inputs: output slot 0 of layers 0 and 1)
//   layer 3: OutputLayer     "OutputLayer"     (layerBindingId 0, input: output slot 0 of layer 2)
// The ComparisonLayer and OutputLayer tensors are declared with dataType Boolean.
struct ComparisonFixture : public ParserFlatbuffersSerializeFixture
{
    // Every argument is spliced verbatim into the JSON text below. Each one must
    // therefore already be valid text for that position in the description.
    //
    // inputShape0         - dimensions of InputLayer0's output tensor, e.g. "[ 2, 2, 2, 2 ]".
    // inputShape1         - dimensions of InputLayer1's output tensor.
    // outputShape         - dimensions of the ComparisonLayer and OutputLayer tensors.
    // inputDataType       - dataType of both input tensors. This file passes
    //                       stringified armnn::DataType enumerator names such as "Float32".
    // comparisonOperation - value of the ComparisonLayer descriptor's "operation"
    //                       field. This file passes names such as "Equal" or "Less".
    explicit ComparisonFixture(const std::string& inputShape0,
                               const std::string& inputShape1,
                               const std::string& outputShape,
                               const std::string& inputDataType,
                               const std::string& comparisonOperation)
    {
        // m_JsonString is not declared in this struct. Presumably it is a member
        // of ParserFlatbuffersSerializeFixture; confirm there.
        // NOTE(review): both InputLayers declare an input slot connected to layer 0.
        // Presumably the deserializer ignores input slots on input layers; confirm.
        m_JsonString = R"(
            {
                inputIds: [0, 1],
                outputIds: [3],
                layers: [
                    {
                        layer_type: "InputLayer",
                        layer: {
                            base: {
                                layerBindingId: 0,
                                base: {
                                    index: 0,
                                    layerName: "InputLayer0",
                                    layerType: "Input",
                                    inputSlots: [{
                                        index: 0,
                                        connection: { sourceLayerIndex:0, outputSlotIndex:0 },
                                    }],
                                    outputSlots: [{
                                        index: 0,
                                        tensorInfo: {
                                            dimensions: )" + inputShape0 + R"(,
                                            dataType: )" + inputDataType + R"(
                                        },
                                    }],
                                },
                            }
                        },
                    },
                    {
                        layer_type: "InputLayer",
                        layer: {
                            base: {
                                layerBindingId: 1,
                                base: {
                                      index:1,
                                      layerName: "InputLayer1",
                                      layerType: "Input",
                                      inputSlots: [{
                                          index: 0,
                                          connection: { sourceLayerIndex:0, outputSlotIndex:0 },
                                      }],
                                      outputSlots: [{
                                          index: 0,
                                          tensorInfo: {
                                              dimensions: )" + inputShape1 + R"(,
                                              dataType: )" + inputDataType + R"(
                                          },
                                      }],
                                },
                            }
                        },
                    },
                    {
                        layer_type: "ComparisonLayer",
                        layer: {
                            base: {
                                 index:2,
                                 layerName: "ComparisonLayer",
                                 layerType: "Comparison",
                                 inputSlots: [{
                                     index: 0,
                                     connection: { sourceLayerIndex:0, outputSlotIndex:0 },
                                 },
                                 {
                                     index: 1,
                                     connection: { sourceLayerIndex:1, outputSlotIndex:0 },
                                 }],
                                 outputSlots: [{
                                     index: 0,
                                     tensorInfo: {
                                         dimensions: )" + outputShape + R"(,
                                         dataType: Boolean
                                     },
                                 }],
                            },
                            descriptor: {
                                operation: )" + comparisonOperation + R"(
                            }
                        },
                    },
                    {
                        layer_type: "OutputLayer",
                        layer: {
                            base:{
                                layerBindingId: 0,
                                base: {
                                    index: 3,
                                    layerName: "OutputLayer",
                                    layerType: "Output",
                                    inputSlots: [{
                                        index: 0,
                                        connection: { sourceLayerIndex:2, outputSlotIndex:0 },
                                    }],
                                    outputSlots: [{
                                        index: 0,
                                        tensorInfo: {
                                            dimensions: )" + outputShape + R"(,
                                            dataType: Boolean
                                        },
                                    }],
                                }
                            }
                        },
                    }
                ]
            }
        )";
        // Setup() is not defined in this file. Presumably it is provided by
        // ParserFlatbuffersSerializeFixture and builds and deserializes the network
        // described by m_JsonString; confirm there.
        Setup();
    }
};
158
// Shared operands and expected results for the simple comparison tests.
//
// Each operand holds 16 values laid out as four runs of four. Comparing the runs
// pairwise gives (1,1) equal, (5,3) greater, (3,5) less and (4,4) equal.
// Every m_Output* vector holds the expected elementwise result of
// "m_InputData0 <op> m_InputData1", using 1 for true and 0 for false.
struct SimpleComparisonTestData
{
    SimpleComparisonTestData() = default;

    // Left-hand operand: runs of 1, 5, 3, 4.
    std::vector<float> m_InputData0 {
        1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
        3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f
    };

    // Right-hand operand: runs of 1, 3, 5, 4.
    std::vector<float> m_InputData1 {
        1.f, 1.f, 1.f, 1.f, 3.f, 3.f, 3.f, 3.f,
        5.f, 5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 4.f
    };

    // m_InputData0 == m_InputData1
    std::vector<uint8_t> m_OutputEqual {
        1, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 1, 1, 1
    };

    // m_InputData0 > m_InputData1
    std::vector<uint8_t> m_OutputGreater {
        0, 0, 0, 0, 1, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0
    };

    // m_InputData0 >= m_InputData1
    std::vector<uint8_t> m_OutputGreaterOrEqual {
        1, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 1, 1, 1, 1
    };

    // m_InputData0 < m_InputData1
    std::vector<uint8_t> m_OutputLess {
        0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 0, 0
    };

    // m_InputData0 <= m_InputData1
    std::vector<uint8_t> m_OutputLessOrEqual {
        1, 1, 1, 1, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1
    };

    // m_InputData0 != m_InputData1
    std::vector<uint8_t> m_OutputNotEqual {
        0, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 0, 0, 0
    };
};
222
223 struct SimpleComparisonFixture : public ComparisonFixture
224 {
225     SimpleComparisonFixture(const std::string& inputDataType,
226                             const std::string& comparisonOperation)
227         : ComparisonFixture("[ 2, 2, 2, 2 ]", // inputShape0
228                             "[ 2, 2, 2, 2 ]", // inputShape1
229                             "[ 2, 2, 2, 2 ]", // outputShape,
230                             inputDataType,
231                             comparisonOperation) {}
232
233     static SimpleComparisonTestData s_TestData;
234 };
235
236 SimpleComparisonTestData SimpleComparisonFixture::s_TestData;
237
// Register one test case per comparison operation for each tested input data type.
// Each invocation expands to its own fixture and a Boost test case named
// <operation><dataType>, e.g. EqualFloat32.
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal,          Float32)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater,        Float32)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, Float32)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less,           Float32)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual,    Float32)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual,       Float32)


// The QuantisedAsymm8 cases are wrapped in ARMNN_NO_DEPRECATE_WARN_BEGIN/END,
// presumably because QuantisedAsymm8 is a deprecated enumerator. It appears to
// be superseded by QAsymmU8, which is tested below; confirm in the ArmNN types header.
ARMNN_NO_DEPRECATE_WARN_BEGIN
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal,          QuantisedAsymm8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater,        QuantisedAsymm8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, QuantisedAsymm8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less,           QuantisedAsymm8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual,    QuantisedAsymm8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual,       QuantisedAsymm8)
ARMNN_NO_DEPRECATE_WARN_END

DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal,          QAsymmU8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater,        QAsymmU8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, QAsymmU8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less,           QAsymmU8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual,    QAsymmU8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual,       QAsymmU8)

// Closes the Deserializer test suite opened with BOOST_AUTO_TEST_SUITE(Deserializer).
BOOST_AUTO_TEST_SUITE_END()