Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / tests / tools / tflite_loader / src / tflite_loader.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "args.h"
18
19 #include <nnfw_experimental.h>
20 #include <nnfw_internal.h>
21
22 #include <misc/EnvVar.h>
23 #include <misc/RandomGenerator.h>
24
25 #include <tflite/Assert.h>
26 #include <tflite/InterpreterSession.h>
27 #include <tflite/ext/kernels/register.h>
28
29 #include <iostream>
30 #include <fstream>
31 #include <memory>
32
33 const int RUN_FAILED = 1;
34
35 using namespace tflite;
36 using namespace nnfw::tflite;
37
38 const int FILE_ERROR = 2;
39 const float DIFFERENCE_THRESHOLD = 10e-5;
40
41 #define NNFW_ASSERT_FAIL(expr, msg)   \
42   if ((expr) != NNFW_STATUS_NO_ERROR) \
43   {                                   \
44     std::cerr << msg << std::endl;    \
45     exit(-1);                         \
46   }
47
48 // Read vector of floats from selected file
49 void readData(const string &path, std::vector<uint8_t> &dest)
50 {
51   std::ifstream in(path);
52   if (!in.good())
53   {
54     std::cerr << "can not open data file " << path << "\n";
55     exit(FILE_ERROR);
56   }
57   in.seekg(0, std::ifstream::end);
58   size_t len = in.tellg();
59   in.seekg(0, std::ifstream::beg);
60
61   assert(dest.size() == len);
62   in.read(reinterpret_cast<char *>(dest.data()), len);
63 }
64
65 template <typename T>
66 void randomData(nnfw::misc::RandomGenerator &randgen, std::vector<uint8_t> &dest)
67 {
68   size_t elements = dest.size() / sizeof(T);
69   assert(dest.size() % sizeof(T) == 0);
70
71   std::vector<T> vec(elements);
72   for (uint64_t i = 0; i < elements; i++)
73   {
74     vec[i] = randgen.generate<T>();
75   }
76   memcpy(dest.data(), vec.data(), elements * sizeof(T));
77 }
78
79 void randomBoolData(nnfw::misc::RandomGenerator &randgen, std::vector<uint8_t> &dest)
80 {
81   size_t elements = dest.size();
82   std::vector<uint8_t> vec(elements);
83   for (uint64_t i = 0; i < elements; i++)
84   {
85     bool value = randgen.generate<bool>();
86     dest[i] = value ? 1 : 0;
87   }
88 }
89
90 inline uint64_t num_elems(const nnfw_tensorinfo *ti)
91 {
92   uint64_t n = 1;
93   for (uint32_t i = 0; i < ti->rank; ++i)
94   {
95     n *= ti->dims[i];
96   }
97   return n;
98 }
99
100 inline size_t sizeOfNnfwType(NNFW_TYPE type)
101 {
102   switch (type)
103   {
104     case NNFW_TYPE_TENSOR_BOOL:
105     case NNFW_TYPE_TENSOR_UINT8:
106     case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
107     case NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
108       return 1;
109     case NNFW_TYPE_TENSOR_FLOAT32:
110     case NNFW_TYPE_TENSOR_INT32:
111       return 4;
112     case NNFW_TYPE_TENSOR_INT64:
113       return 8;
114     default:
115       throw std::runtime_error{"Invalid tensor type"};
116   }
117 }
118
119 template <typename T>
120 bool compareBuffersExact(const T *ref_buf, const std::vector<uint8_t> &act_buf, uint32_t index)
121 {
122   bool match = true;
123   for (uint32_t e = 0; e < act_buf.size() / sizeof(T); e++)
124   {
125     T ref = ref_buf[e];
126     T act = reinterpret_cast<const T *>(act_buf.data())[e];
127
128     if (ref != act)
129     {
130       std::cerr << "Output #" << index << ", Element Index : " << e << ", ref: " << ref
131                 << ", act: " << act << std::endl;
132       match = false;
133     }
134   }
135
136   return match;
137 }
138
139 bool compareBuffersExactBool(const uint8_t *ref_buf, const std::vector<uint8_t> &act_buf,
140                              uint32_t index)
141 {
142   bool match = true;
143   for (uint32_t e = 0; e < act_buf.size() / sizeof(uint8_t); e++)
144   {
145     uint8_t ref_raw = ref_buf[e];
146     bool ref = (ref_raw != 0 ? true : false);
147     uint8_t act_raw = reinterpret_cast<const uint8_t *>(act_buf.data())[e];
148     bool act = (act_raw != 0 ? true : false);
149     if (ref != act)
150     {
151       std::cerr << "Output #" << index << ", Element Index : " << e << ", ref: " << ref
152                 << ", act: " << act << std::endl;
153       match = false;
154     }
155   }
156
157   return match;
158 }
159
160 int main(const int argc, char **argv)
161 {
162   TFLiteRun::Args args(argc, argv);
163
164   auto tflite_file = args.getTFLiteFilename();
165   auto data_files = args.getDataFilenames();
166
167   if (tflite_file.empty())
168   {
169     args.print(argv);
170     return RUN_FAILED;
171   }
172
173   std::cout << "[Execution] Stage start!" << std::endl;
174   // Loading
175   nnfw_session *onert_session = nullptr;
176   NNFW_ASSERT_FAIL(nnfw_create_session(&onert_session), "[ ERROR ] Failure during model load");
177   if (onert_session == nullptr)
178   {
179     std::cerr << "[ ERROR ] Failure to open session" << std::endl;
180     exit(-1);
181   }
182
183   NNFW_ASSERT_FAIL(nnfw_load_model_from_modelfile(onert_session, tflite_file.c_str()),
184                    "[ ERROR ] Failure during model load");
185
186   uint32_t num_inputs;
187   uint32_t num_outputs;
188   NNFW_ASSERT_FAIL(nnfw_input_size(onert_session, &num_inputs),
189                    "[ ERROR ] Failure during get model inputs");
190   NNFW_ASSERT_FAIL(nnfw_output_size(onert_session, &num_outputs),
191                    "[ ERROR ] Failure during get model outputs");
192
193   std::cout << "[Execution] Model is deserialized!" << std::endl;
194
195   // Compile
196   nnfw_prepare(onert_session);
197
198   std::cout << "[Execution] Model compiled!" << std::endl;
199
200   // Prepare input/output data
201   std::vector<std::vector<uint8_t>> inputs(num_inputs);
202   std::vector<std::vector<uint8_t>> outputs(num_outputs);
203
204   bool generate_data = data_files.empty();
205   bool read_data = data_files.size() == num_inputs;
206   if (!generate_data && !read_data)
207   {
208     std::cerr << "[ ERROR ] "
209               << "Wrong number of input files." << std::endl;
210     exit(1);
211   }
212
213   const int seed = 1; /* TODO Add an option for seed value */
214   nnfw::misc::RandomGenerator randgen{seed, 0.0f, 2.0f};
215
216   for (uint32_t i = 0; i < num_inputs; i++)
217   {
218     nnfw_tensorinfo ti_input;
219     NNFW_ASSERT_FAIL(nnfw_input_tensorinfo(onert_session, i, &ti_input),
220                      "[ ERROR ] Failure during get input data info");
221     size_t input_size = num_elems(&ti_input) * sizeOfNnfwType(ti_input.dtype);
222
223     inputs[i].resize(input_size);
224
225     if (generate_data)
226     {
227       switch (ti_input.dtype)
228       {
229         case NNFW_TYPE_TENSOR_BOOL:
230           randomBoolData(randgen, inputs[i]);
231           break;
232         case NNFW_TYPE_TENSOR_UINT8:
233         case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
234           randomData<uint8_t>(randgen, inputs[i]);
235           break;
236         case NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
237           randomData<int8_t>(randgen, inputs[i]);
238           break;
239         case NNFW_TYPE_TENSOR_FLOAT32:
240           randomData<float>(randgen, inputs[i]);
241           break;
242         case NNFW_TYPE_TENSOR_INT32:
243           randomData<int32_t>(randgen, inputs[i]);
244           break;
245         case NNFW_TYPE_TENSOR_INT64:
246           randomData<uint64_t>(randgen, inputs[i]);
247           break;
248         default:
249           std::cerr << "[ ERROR ] "
250                     << "Unspported input data type" << std::endl;
251           exit(-1);
252           break;
253       }
254     }
255     else /* read_data */
256       readData(data_files[i], inputs[i]);
257
258     NNFW_ASSERT_FAIL(nnfw_set_input(onert_session, i, ti_input.dtype, inputs[i].data(), input_size),
259                      "[ ERROR ] Failure to set input tensor buffer");
260   }
261
262   std::cout << "[Execution] Input data is defined!" << std::endl;
263
264   for (uint32_t i = 0; i < num_outputs; i++)
265   {
266     nnfw_tensorinfo ti_output;
267     NNFW_ASSERT_FAIL(nnfw_output_tensorinfo(onert_session, i, &ti_output),
268                      "[ ERROR ] Failure during get output tensor info");
269
270     uint64_t output_elements = num_elems(&ti_output);
271     size_t output_size = output_elements * sizeOfNnfwType(ti_output.dtype);
272     outputs[i].resize(output_size);
273
274     NNFW_ASSERT_FAIL(
275       nnfw_set_output(onert_session, i, ti_output.dtype, outputs[i].data(), output_size),
276       "[ ERROR ] Failure to set output tensor buffer");
277   }
278
279   // Execute
280   NNFW_ASSERT_FAIL(nnfw_run(onert_session), "[Execution] Can't execute");
281
282   std::cout << "[Execution] Done!" << std::endl;
283
284   // Compare with tflite
285   std::cout << "[Comparison] Stage start!" << std::endl;
286   // Read tflite model
287   StderrReporter error_reporter;
288   auto model = FlatBufferModel::BuildFromFile(tflite_file.c_str(), &error_reporter);
289
290   BuiltinOpResolver resolver;
291   InterpreterBuilder builder(*model, resolver);
292
293   std::unique_ptr<Interpreter> interpreter;
294   try
295   {
296     TFLITE_ENSURE(builder(&interpreter));
297   }
298   catch (const std::exception &e)
299   {
300     std::cerr << e.what() << std::endl;
301     exit(FILE_ERROR);
302   }
303   interpreter->SetNumThreads(nnfw::misc::EnvVar("THREAD").asInt(-1));
304
305   auto sess = std::make_shared<nnfw::tflite::InterpreterSession>(interpreter.get());
306   sess->prepare();
307   // Set input and run
308   for (uint32_t i = 0; i < num_inputs; i++)
309   {
310     auto input_tensor = interpreter->tensor(interpreter->inputs().at(i));
311     memcpy(input_tensor->data.uint8, inputs[i].data(), inputs[i].size());
312   }
313   if (!sess->run())
314   {
315     std::cout << "[Comparison] TFLite run failed!" << std::endl;
316     assert(0 && "Run failed!");
317   }
318   std::cout << "[Comparison] TFLite run done!" << std::endl;
319
320   // Calculate max difference over all outputs
321   float max_float_difference = 0.0f;
322   bool find_unmatched_output = false;
323
324   for (uint32_t out_idx = 0; out_idx < num_outputs; out_idx++)
325   {
326     nnfw_tensorinfo ti;
327     nnfw_output_tensorinfo(onert_session, out_idx, &ti);
328
329     bool matched = true;
330     // Check output tensor values
331
332     const auto &ref_output = interpreter->tensor(interpreter->outputs().at(out_idx))->data;
333     const auto &output = outputs[out_idx];
334
335     switch (ti.dtype)
336     {
337       case NNFW_TYPE_TENSOR_BOOL:
338         matched = compareBuffersExactBool(ref_output.uint8, output, out_idx);
339         break;
340       case NNFW_TYPE_TENSOR_UINT8:
341       case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
342         matched = compareBuffersExact<uint8_t>(ref_output.uint8, output, out_idx);
343         break;
344       case NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
345         matched = compareBuffersExact<int8_t>(ref_output.int8, output, out_idx);
346         break;
347       case NNFW_TYPE_TENSOR_INT32:
348         matched = compareBuffersExact<int32_t>(ref_output.i32, output, out_idx);
349         break;
350       case NNFW_TYPE_TENSOR_FLOAT32:
351         // TODO better way for handling FP error?
352         for (uint32_t e = 0; e < num_elems(&ti); e++)
353         {
354           float refval = ref_output.f[e];
355           float val = reinterpret_cast<const float *>(output.data())[e];
356           if (std::abs(refval - val) > max_float_difference)
357             max_float_difference = std::abs(refval - val);
358
359           if (max_float_difference > DIFFERENCE_THRESHOLD)
360             matched = false;
361         }
362         break;
363       case NNFW_TYPE_TENSOR_INT64:
364         matched = compareBuffersExact<int64_t>(ref_output.i64, output, out_idx);
365         break;
366       default:
367         throw std::runtime_error{"Invalid tensor type"};
368     }
369
370     if (!matched)
371       find_unmatched_output = true;
372   }
373
374   // Print results
375   std::cout << "[Comparison] Max float difference: " << max_float_difference << std::endl;
376   int ret = 0;
377   if (find_unmatched_output)
378   {
379     std::cout << "[Comparison] outputs is not equal!" << std::endl;
380     if (max_float_difference > DIFFERENCE_THRESHOLD)
381     {
382       std::cout << "[Comparison] Float outputs is not equal!" << std::endl;
383     }
384     ret = 1;
385   }
386   else
387   {
388     std::cout << "[Comparison] Outputs is equal!" << std::endl;
389   }
390   std::cout << "[Comparison] Done!" << std::endl;
391
392   nnfw_close_session(onert_session);
393
394   return ret;
395 }