ce099210b00cc4315ccfb2b9c5562ca7a927f6cf
[platform/core/ml/nnfw.git] / tests / tools / tflite_loader / src / tflite_loader.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "tflite/ext/kernels/register.h"
18
19 #include "args.h"
20 #include "tflite/InterpreterSession.h"
21 #include "tflite/Assert.h"
22 #include "tflite/Diff.h"
23 #include "misc/tensor/IndexIterator.h"
24
25 #include <iostream>
26 #include <fstream>
27
28 #include "compiler/Compiler.h"
29 #include "exec/Execution.h"
30 #include "ir/Graph.h"
31
32 #include "tflite_loader.h"
33
34 #include <memory>
35
36 const int RUN_FAILED = 1;
37
38 using namespace tflite;
39 using namespace nnfw::tflite;
40
41 const int FILE_ERROR = 2;
42 const float DIFFERENCE_THRESHOLD = 10e-5;
43
44 // Read vector of floats from selected file
45 std::vector<float> readData(const string &path)
46 {
47   std::ifstream in(path);
48   if (!in.good())
49   {
50     std::cerr << "can not open data file " << path << "\n";
51     exit(FILE_ERROR);
52   }
53   in.seekg(0, std::ifstream::end);
54   size_t len = in.tellg();
55   in.seekg(0, std::ifstream::beg);
56   assert(len % sizeof(float) == 0);
57   size_t size = len / sizeof(float);
58   std::vector<float> vec(size);
59   for (size_t i = 0; i < size; ++i)
60   {
61     in.read(reinterpret_cast<char *>(&vec[i]), sizeof(float));
62   }
63   return vec;
64 }
65
66 std::vector<float> randomData(nnfw::misc::RandomGenerator &randgen, const uint64_t size)
67 {
68   std::vector<float> vec(size);
69   for (uint64_t i = 0; i < size; i++)
70   {
71     vec[i] = randgen.generate<float>();
72   }
73   return vec;
74 }
75
76 void executeGraph(const std::shared_ptr<onert::ir::Graph> &g,
77                   const std::vector<std::vector<float>> &inputs,
78                   std::vector<std::vector<float>> &outputs)
79 {
80   auto subgs = std::make_shared<onert::ir::Subgraphs>();
81   subgs->push(onert::ir::SubgraphIndex{0}, g);
82   auto compiler = new onert::compiler::Compiler(subgs);
83   std::shared_ptr<onert::exec::ExecutorMap> executors;
84   // Compilation
85   try
86   {
87     executors = compiler->compile();
88   }
89   catch (const std::exception &e)
90   {
91     std::cerr << "[Execution] Can't compile model" << std::endl;
92     std::cerr << e.what() << std::endl;
93     exit(-1);
94   }
95
96   std::cout << "[Execution] Graph compiled!" << std::endl;
97
98   auto execution = std::make_shared<onert::exec::Execution>(executors);
99
100   // Setting IO
101   try
102   {
103     // Verify input shapes
104     auto num_inputs = inputs.size();
105     for (size_t i = 0; i < num_inputs; i++)
106     {
107       auto input_operand_idx = g->getInputs().at(i);
108       auto input_shape = g->operands().at(input_operand_idx).shape();
109       assert(inputs[i].size() == input_shape.num_elements());
110     }
111
112     // Set output shapes
113     auto num_outputs = g->getOutputs().size();
114     outputs.resize(num_outputs);
115     for (uint32_t i = 0; i < num_outputs; i++)
116     {
117       auto output_operand_idx = g->getOutputs().at(i);
118       auto output_shape = g->operands().at(output_operand_idx).shape();
119       outputs[i].resize(output_shape.num_elements());
120     }
121
122     for (size_t i = 0; i < num_inputs; i++)
123       execution->setInput(onert::ir::IOIndex(i), inputs[i].data(),
124                           inputs[i].size() * sizeof(float));
125     for (uint32_t i = 0; i < num_outputs; i++)
126       execution->setOutput(onert::ir::IOIndex(i), outputs[i].data(),
127                            outputs[i].size() * sizeof(float));
128   }
129   catch (const std::exception &e)
130   {
131     std::cerr << "[Execution] Can't set model IO" << std::endl;
132     std::cerr << e.what() << '\n';
133     exit(-1);
134   }
135
136   try
137   {
138     execution->execute();
139   }
140   catch (const std::exception &e)
141   {
142     std::cerr << "[Execution] Can't execute" << std::endl;
143     std::cerr << e.what() << '\n';
144     exit(-1);
145   }
146
147   std::cout << "[Execution] Done!" << std::endl;
148
149   delete compiler;
150 }
151
152 int main(const int argc, char **argv)
153 {
154   TFLiteRun::Args args(argc, argv);
155
156   auto tflite_file = args.getTFLiteFilename();
157   auto data_files = args.getDataFilenames();
158
159   if (tflite_file.empty())
160   {
161     args.print(argv);
162     return RUN_FAILED;
163   }
164
165   std::cout << "[Execution] Stage start!" << std::endl;
166   std::shared_ptr<onert::ir::Graph> test_graph;
167   // Loading
168   try
169   {
170     test_graph =
171         onert::tflite_loader::loadModel(tflite_file.c_str())->at(onert::ir::SubgraphIndex{0});
172   }
173   catch (std::exception &e)
174   {
175     std::cerr << "[ ERROR ] "
176               << "Failure during model load" << std::endl;
177     std::cerr << e.what() << std::endl;
178     exit(-1);
179   }
180
181   // TODO: Support another input/output types
182   for (const auto &input_idx : test_graph->getInputs())
183   {
184     const auto input_type = test_graph->operands().at(input_idx).typeInfo().type();
185     assert(input_type == onert::ir::DataType::FLOAT32 && "Only FLOAT32 inputs are supported");
186   }
187   for (const auto &output_idx : test_graph->getOutputs())
188   {
189     const auto output_type = test_graph->operands().at(output_idx).typeInfo().type();
190     assert(output_type == onert::ir::DataType::FLOAT32 && "Only FLOAT32 outputs are supported");
191   }
192
193   std::cout << "[Execution] Model is deserialized!" << std::endl;
194   auto num_inputs = test_graph->getInputs().size();
195   std::vector<std::vector<float>> inputs(num_inputs);
196   bool generate_data = data_files.empty();
197   bool read_data = data_files.size() == num_inputs;
198   if (num_inputs == 0)
199   {
200     std::cerr << "[ ERROR ] "
201               << "No inputs in model => execution is not possible" << std::endl;
202     exit(1);
203   }
204   if (!generate_data && !read_data)
205   {
206     std::cerr << "[ ERROR ] "
207               << "Wrong number of input files." << std::endl;
208     exit(1);
209   }
210
211   const int seed = 1; /* TODO Add an option for seed value */
212   nnfw::misc::RandomGenerator randgen{seed, 0.0f, 2.0f};
213   try
214   {
215     for (uint32_t i = 0; i < num_inputs; i++)
216     {
217       if (generate_data)
218       {
219         uint64_t sz =
220             test_graph->operands().at(test_graph->getInputs().at(i)).shape().num_elements();
221         inputs[i] = randomData(randgen, sz);
222       }
223       else /* read_data */
224         inputs[i] = readData(data_files[i]);
225     }
226   }
227   catch (std::exception &e)
228   {
229     std::cerr << "[ ERROR ] "
230               << "Failure during input data generation" << std::endl;
231     std::cerr << e.what() << std::endl;
232     exit(-1);
233   }
234
235   std::cout << "[Execution] Input data is defined!" << std::endl;
236   std::vector<std::vector<float>> outputs;
237   // Run graph
238   executeGraph(test_graph, inputs, outputs);
239   // Compare with tflite
240   std::cout << "[Comparison] Stage start!" << std::endl;
241   // Read tflite model
242   StderrReporter error_reporter;
243   auto model = FlatBufferModel::BuildFromFile(tflite_file.c_str(), &error_reporter);
244
245   BuiltinOpResolver resolver;
246   InterpreterBuilder builder(*model, resolver);
247
248   std::unique_ptr<Interpreter> interpreter;
249   try
250   {
251     TFLITE_ENSURE(builder(&interpreter));
252   }
253   catch (const std::exception &e)
254   {
255     std::cerr << e.what() << std::endl;
256     exit(FILE_ERROR);
257   }
258   interpreter->SetNumThreads(2);
259
260   auto sess = std::make_shared<nnfw::tflite::InterpreterSession>(interpreter.get());
261   sess->prepare();
262   // Set input and run
263   for (uint32_t i = 0; i < num_inputs; i++)
264   {
265     auto input_tensor = interpreter->tensor(interpreter->inputs().at(i));
266     memcpy(input_tensor->data.f, inputs[i].data(), inputs[i].size() * sizeof(float));
267   }
268   if (!sess->run())
269   {
270     std::cout << "[Comparison] TFLite run failed!" << std::endl;
271     assert(0 && "Run failed!");
272   }
273   std::cout << "[Comparison] TFLite run done!" << std::endl;
274
275   // Calculate max difference over all outputs
276   float max_difference = 0.0f;
277   auto num_outputs = test_graph->getOutputs().size();
278   for (uint32_t out_idx = 0; out_idx < num_outputs; out_idx++)
279   {
280     const auto &tflite_output_tensor = interpreter->tensor(interpreter->outputs().at(out_idx));
281     const auto &nnfw_output_tensor = outputs[out_idx];
282
283     if (nnfw_output_tensor.size() != tflite_output_tensor->bytes / sizeof(float))
284       std::cout << "[Comparison] Different size of outputs!" << std::endl;
285     // Check max difference
286     float *tflite_out_ptr = tflite_output_tensor->data.f;
287     for (const auto &nnfw_out : nnfw_output_tensor)
288     {
289       if (std::abs(nnfw_out - *tflite_out_ptr) > max_difference)
290         max_difference = std::abs(nnfw_out - *tflite_out_ptr);
291
292       tflite_out_ptr++;
293     }
294   }
295
296   // Print results
297   std::cout << "[Comparison] Max difference: " << max_difference << std::endl;
298   int ret = 0;
299   if (max_difference > DIFFERENCE_THRESHOLD)
300   {
301     std::cout << "[Comparison] Outputs is not equal!" << std::endl;
302     ret = 1;
303   }
304   else
305   {
306     std::cout << "[Comparison] Outputs is equal!" << std::endl;
307   }
308   std::cout << "[Comparison] Done!" << std::endl;
309
310   return ret;
311 }