Release 18.08
[platform/upstream/armnn.git] / src / armnnTfLiteParser / test / GetInputsOutputs.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include <boost/test/unit_test.hpp>
6 #include "ParserFlatbuffersFixture.hpp"
7 #include "../TfLiteParser.hpp"
8
9 using armnnTfLiteParser::TfLiteParser;
10 using ModelPtr = TfLiteParser::ModelPtr;
11
12 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
13
14 struct GetInputsOutputsMainFixture : public ParserFlatbuffersFixture
15 {
16     explicit GetInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs)
17     {
18         m_JsonString = R"(
19         {
20             "version": 3,
21             "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
22             "subgraphs": [
23             {
24                 "tensors": [
25                 {
26                     "shape": [ 1, 1, 1, 1 ] ,
27                     "type": "UINT8",
28                             "buffer": 0,
29                             "name": "OutputTensor",
30                             "quantization": {
31                                 "min": [ 0.0 ],
32                                 "max": [ 255.0 ],
33                                 "scale": [ 1.0 ],
34                                 "zero_point": [ 0 ]
35                             }
36                 },
37                 {
38                     "shape": [ 1, 2, 2, 1 ] ,
39                     "type": "UINT8",
40                             "buffer": 1,
41                             "name": "InputTensor",
42                             "quantization": {
43                                 "min": [ -1.2 ],
44                                 "max": [ 25.5 ],
45                                 "scale": [ 0.25 ],
46                                 "zero_point": [ 10 ]
47                             }
48                 }
49                 ],
50                 "inputs": [ 1 ],
51                 "outputs": [ 0 ],
52                 "operators": [ {
53                         "opcode_index": 0,
54                         "inputs":  )"
55                             + inputs
56                             + R"(,
57                         "outputs": )"
58                             + outputs
59                             + R"(,
60                         "builtin_options_type": "Pool2DOptions",
61                         "builtin_options":
62                         {
63                             "padding": "VALID",
64                             "stride_w": 2,
65                             "stride_h": 2,
66                             "filter_width": 2,
67                             "filter_height": 2,
68                             "fused_activation_function": "NONE"
69                         },
70                         "custom_options_format": "FLEXBUFFERS"
71                     } ]
72                 },
73                 {
74                     "tensors": [
75                         {
76                             "shape": [ 1, 3, 3, 1 ],
77                             "type": "UINT8",
78                             "buffer": 0,
79                             "name": "ConvInputTensor",
80                             "quantization": {
81                                 "scale": [ 1.0 ],
82                                 "zero_point": [ 0 ],
83                             }
84                         },
85                         {
86                             "shape": [ 1, 1, 1, 1 ],
87                             "type": "UINT8",
88                             "buffer": 1,
89                             "name": "ConvOutputTensor",
90                             "quantization": {
91                                 "min": [ 0.0 ],
92                                 "max": [ 511.0 ],
93                                 "scale": [ 2.0 ],
94                                 "zero_point": [ 0 ],
95                             }
96                         },
97                         {
98                             "shape": [ 1, 3, 3, 1 ],
99                             "type": "UINT8",
100                             "buffer": 2,
101                             "name": "filterTensor",
102                             "quantization": {
103                                 "min": [ 0.0 ],
104                                 "max": [ 255.0 ],
105                                 "scale": [ 1.0 ],
106                                 "zero_point": [ 0 ],
107                             }
108                         }
109                     ],
110                     "inputs": [ 0 ],
111                     "outputs": [ 1 ],
112                     "operators": [
113                         {
114                             "opcode_index": 0,
115                             "inputs": [ 0, 2 ],
116                             "outputs": [ 1 ],
117                             "builtin_options_type": "Conv2DOptions",
118                             "builtin_options": {
119                                 "padding": "VALID",
120                                 "stride_w": 1,
121                                 "stride_h": 1,
122                                 "fused_activation_function": "NONE"
123                             },
124                             "custom_options_format": "FLEXBUFFERS"
125                         }
126                     ],
127                 }
128             ],
129             "description": "Test Subgraph Inputs Outputs",
130             "buffers" : [
131                     { },
132                     { },
133                     { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
134                     { },
135                 ]
136         })";
137
138         ReadStringToBinary();
139     }
140
141 };
142
143 struct GetEmptyInputsOutputsFixture : GetInputsOutputsMainFixture
144 {
145     GetEmptyInputsOutputsFixture() : GetInputsOutputsMainFixture("[ ]", "[ ]") {}
146 };
147
148 struct GetInputsOutputsFixture : GetInputsOutputsMainFixture
149 {
150     GetInputsOutputsFixture() : GetInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {}
151 };
152
153 BOOST_FIXTURE_TEST_CASE(GetEmptyInputs, GetEmptyInputsOutputsFixture)
154 {
155     TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
156     TfLiteParser::TensorRawPtrVector tensors = TfLiteParser::GetInputs(model, 0, 0);
157     BOOST_CHECK_EQUAL(0, tensors.size());
158 }
159
160 BOOST_FIXTURE_TEST_CASE(GetEmptyOutputs, GetEmptyInputsOutputsFixture)
161 {
162     TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
163     TfLiteParser::TensorRawPtrVector tensors = TfLiteParser::GetOutputs(model, 0, 0);
164     BOOST_CHECK_EQUAL(0, tensors.size());
165 }
166
167 BOOST_FIXTURE_TEST_CASE(GetInputs, GetInputsOutputsFixture)
168 {
169     TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
170     TfLiteParser::TensorRawPtrVector tensors = TfLiteParser::GetInputs(model, 0, 0);
171     BOOST_CHECK_EQUAL(1, tensors.size());
172     CheckTensors(tensors[0], 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1,
173                       "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 });
174 }
175
176 BOOST_FIXTURE_TEST_CASE(GetOutputs, GetInputsOutputsFixture)
177 {
178     TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
179     TfLiteParser::TensorRawPtrVector tensors = TfLiteParser::GetOutputs(model, 0, 0);
180     BOOST_CHECK_EQUAL(1, tensors.size());
181     CheckTensors(tensors[0], 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0,
182                       "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
183 }
184
185 BOOST_FIXTURE_TEST_CASE(GetInputsMultipleInputs, GetInputsOutputsFixture)
186 {
187     TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
188     TfLiteParser::TensorRawPtrVector tensors = TfLiteParser::GetInputs(model, 1, 0);
189     BOOST_CHECK_EQUAL(2, tensors.size());
190     CheckTensors(tensors[0], 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0,
191                       "ConvInputTensor", { }, { }, { 1.0f }, { 0 });
192     CheckTensors(tensors[1], 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 2,
193                       "filterTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
194 }
195
196 BOOST_FIXTURE_TEST_CASE(GetOutputs2, GetInputsOutputsFixture)
197 {
198     TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
199     TfLiteParser::TensorRawPtrVector tensors = TfLiteParser::GetOutputs(model, 1, 0);
200     BOOST_CHECK_EQUAL(1, tensors.size());
201     CheckTensors(tensors[0], 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1,
202                       "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 });
203 }
204
205 BOOST_AUTO_TEST_CASE(GetInputsNullModel)
206 {
207     BOOST_CHECK_THROW(TfLiteParser::GetInputs(nullptr, 0, 0), armnn::ParseException);
208 }
209
210 BOOST_AUTO_TEST_CASE(GetOutputsNullModel)
211 {
212     BOOST_CHECK_THROW(TfLiteParser::GetOutputs(nullptr, 0, 0), armnn::ParseException);
213 }
214
215 BOOST_FIXTURE_TEST_CASE(GetInputsInvalidSubgraph, GetInputsOutputsFixture)
216 {
217     TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
218     BOOST_CHECK_THROW(TfLiteParser::GetInputs(model, 2, 0), armnn::ParseException);
219 }
220
221 BOOST_FIXTURE_TEST_CASE(GetOutputsInvalidSubgraph, GetInputsOutputsFixture)
222 {
223     TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
224     BOOST_CHECK_THROW(TfLiteParser::GetOutputs(model, 2, 0), armnn::ParseException);
225 }
226
227 BOOST_FIXTURE_TEST_CASE(GetInputsInvalidOperator, GetInputsOutputsFixture)
228 {
229     TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
230     BOOST_CHECK_THROW(TfLiteParser::GetInputs(model, 0, 1), armnn::ParseException);
231 }
232
233 BOOST_FIXTURE_TEST_CASE(GetOutputsInvalidOperator, GetInputsOutputsFixture)
234 {
235     TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
236     BOOST_CHECK_THROW(TfLiteParser::GetOutputs(model, 0, 1), armnn::ParseException);
237 }
238
239 BOOST_AUTO_TEST_SUITE_END()