/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "args.h"

#include <nnfw_experimental.h>
#include <nnfw_internal.h>

#include <misc/EnvVar.h>
#include <misc/RandomGenerator.h>

#include <tflite/Assert.h>
#include <tflite/InterpreterSession.h>
#include <tflite/ext/kernels/register.h>

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
33 const int RUN_FAILED = 1;
35 using namespace tflite;
36 using namespace nnfw::tflite;
38 const int FILE_ERROR = 2;
39 const float DIFFERENCE_THRESHOLD = 10e-5;
41 #define NNFW_ASSERT_FAIL(expr, msg) \
42 if ((expr) != NNFW_STATUS_NO_ERROR) \
44 std::cerr << msg << std::endl; \
48 // Read vector of floats from selected file
49 void readData(const string &path, std::vector<uint8_t> &dest)
51 std::ifstream in(path);
54 std::cerr << "can not open data file " << path << "\n";
57 in.seekg(0, std::ifstream::end);
58 size_t len = in.tellg();
59 in.seekg(0, std::ifstream::beg);
61 assert(dest.size() == len);
62 in.read(reinterpret_cast<char *>(dest.data()), len);
66 void randomData(nnfw::misc::RandomGenerator &randgen, std::vector<uint8_t> &dest)
68 size_t elements = dest.size() / sizeof(T);
69 assert(dest.size() % sizeof(T) == 0);
71 std::vector<T> vec(elements);
72 for (uint64_t i = 0; i < elements; i++)
74 vec[i] = randgen.generate<T>();
76 memcpy(dest.data(), vec.data(), elements * sizeof(T));
79 void randomBoolData(nnfw::misc::RandomGenerator &randgen, std::vector<uint8_t> &dest)
81 size_t elements = dest.size();
82 std::vector<uint8_t> vec(elements);
83 for (uint64_t i = 0; i < elements; i++)
85 bool value = randgen.generate<bool>();
86 dest[i] = value ? 1 : 0;
90 inline uint64_t num_elems(const nnfw_tensorinfo *ti)
93 for (uint32_t i = 0; i < ti->rank; ++i)
100 inline size_t sizeOfNnfwType(NNFW_TYPE type)
104 case NNFW_TYPE_TENSOR_BOOL:
105 case NNFW_TYPE_TENSOR_UINT8:
106 case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
107 case NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
109 case NNFW_TYPE_TENSOR_FLOAT32:
110 case NNFW_TYPE_TENSOR_INT32:
112 case NNFW_TYPE_TENSOR_INT64:
115 throw std::runtime_error{"Invalid tensor type"};
119 template <typename T>
120 bool compareBuffersExact(const T *ref_buf, const std::vector<uint8_t> &act_buf, uint32_t index)
123 for (uint32_t e = 0; e < act_buf.size() / sizeof(T); e++)
126 T act = reinterpret_cast<const T *>(act_buf.data())[e];
130 std::cerr << "Output #" << index << ", Element Index : " << e << ", ref: " << ref
131 << ", act: " << act << std::endl;
139 bool compareBuffersExactBool(const uint8_t *ref_buf, const std::vector<uint8_t> &act_buf,
143 for (uint32_t e = 0; e < act_buf.size() / sizeof(uint8_t); e++)
145 uint8_t ref_raw = ref_buf[e];
146 bool ref = (ref_raw != 0 ? true : false);
147 uint8_t act_raw = reinterpret_cast<const uint8_t *>(act_buf.data())[e];
148 bool act = (act_raw != 0 ? true : false);
151 std::cerr << "Output #" << index << ", Element Index : " << e << ", ref: " << ref
152 << ", act: " << act << std::endl;
160 int main(const int argc, char **argv)
162 TFLiteRun::Args args(argc, argv);
164 auto tflite_file = args.getTFLiteFilename();
165 auto data_files = args.getDataFilenames();
167 if (tflite_file.empty())
173 std::cout << "[Execution] Stage start!" << std::endl;
175 nnfw_session *onert_session = nullptr;
176 NNFW_ASSERT_FAIL(nnfw_create_session(&onert_session), "[ ERROR ] Failure during model load");
177 if (onert_session == nullptr)
179 std::cerr << "[ ERROR ] Failure to open session" << std::endl;
183 NNFW_ASSERT_FAIL(nnfw_load_model_from_modelfile(onert_session, tflite_file.c_str()),
184 "[ ERROR ] Failure during model load");
187 uint32_t num_outputs;
188 NNFW_ASSERT_FAIL(nnfw_input_size(onert_session, &num_inputs),
189 "[ ERROR ] Failure during get model inputs");
190 NNFW_ASSERT_FAIL(nnfw_output_size(onert_session, &num_outputs),
191 "[ ERROR ] Failure during get model outputs");
193 std::cout << "[Execution] Model is deserialized!" << std::endl;
196 nnfw_prepare(onert_session);
198 std::cout << "[Execution] Model compiled!" << std::endl;
200 // Prepare input/output data
201 std::vector<std::vector<uint8_t>> inputs(num_inputs);
202 std::vector<std::vector<uint8_t>> outputs(num_outputs);
204 bool generate_data = data_files.empty();
205 bool read_data = data_files.size() == num_inputs;
206 if (!generate_data && !read_data)
208 std::cerr << "[ ERROR ] "
209 << "Wrong number of input files." << std::endl;
213 const int seed = 1; /* TODO Add an option for seed value */
214 nnfw::misc::RandomGenerator randgen{seed, 0.0f, 2.0f};
216 for (uint32_t i = 0; i < num_inputs; i++)
218 nnfw_tensorinfo ti_input;
219 NNFW_ASSERT_FAIL(nnfw_input_tensorinfo(onert_session, i, &ti_input),
220 "[ ERROR ] Failure during get input data info");
221 size_t input_size = num_elems(&ti_input) * sizeOfNnfwType(ti_input.dtype);
223 inputs[i].resize(input_size);
227 switch (ti_input.dtype)
229 case NNFW_TYPE_TENSOR_BOOL:
230 randomBoolData(randgen, inputs[i]);
232 case NNFW_TYPE_TENSOR_UINT8:
233 case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
234 randomData<uint8_t>(randgen, inputs[i]);
236 case NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
237 randomData<int8_t>(randgen, inputs[i]);
239 case NNFW_TYPE_TENSOR_FLOAT32:
240 randomData<float>(randgen, inputs[i]);
242 case NNFW_TYPE_TENSOR_INT32:
243 randomData<int32_t>(randgen, inputs[i]);
245 case NNFW_TYPE_TENSOR_INT64:
246 randomData<uint64_t>(randgen, inputs[i]);
249 std::cerr << "[ ERROR ] "
250 << "Unspported input data type" << std::endl;
256 readData(data_files[i], inputs[i]);
258 NNFW_ASSERT_FAIL(nnfw_set_input(onert_session, i, ti_input.dtype, inputs[i].data(), input_size),
259 "[ ERROR ] Failure to set input tensor buffer");
262 std::cout << "[Execution] Input data is defined!" << std::endl;
264 for (uint32_t i = 0; i < num_outputs; i++)
266 nnfw_tensorinfo ti_output;
267 NNFW_ASSERT_FAIL(nnfw_output_tensorinfo(onert_session, i, &ti_output),
268 "[ ERROR ] Failure during get output tensor info");
270 uint64_t output_elements = num_elems(&ti_output);
271 size_t output_size = output_elements * sizeOfNnfwType(ti_output.dtype);
272 outputs[i].resize(output_size);
275 nnfw_set_output(onert_session, i, ti_output.dtype, outputs[i].data(), output_size),
276 "[ ERROR ] Failure to set output tensor buffer");
280 NNFW_ASSERT_FAIL(nnfw_run(onert_session), "[Execution] Can't execute");
282 std::cout << "[Execution] Done!" << std::endl;
284 // Compare with tflite
285 std::cout << "[Comparison] Stage start!" << std::endl;
287 StderrReporter error_reporter;
288 auto model = FlatBufferModel::BuildFromFile(tflite_file.c_str(), &error_reporter);
290 BuiltinOpResolver resolver;
291 InterpreterBuilder builder(*model, resolver);
293 std::unique_ptr<Interpreter> interpreter;
296 TFLITE_ENSURE(builder(&interpreter));
298 catch (const std::exception &e)
300 std::cerr << e.what() << std::endl;
303 interpreter->SetNumThreads(nnfw::misc::EnvVar("THREAD").asInt(-1));
305 auto sess = std::make_shared<nnfw::tflite::InterpreterSession>(interpreter.get());
308 for (uint32_t i = 0; i < num_inputs; i++)
310 auto input_tensor = interpreter->tensor(interpreter->inputs().at(i));
311 memcpy(input_tensor->data.uint8, inputs[i].data(), inputs[i].size());
315 std::cout << "[Comparison] TFLite run failed!" << std::endl;
316 assert(0 && "Run failed!");
318 std::cout << "[Comparison] TFLite run done!" << std::endl;
320 // Calculate max difference over all outputs
321 float max_float_difference = 0.0f;
322 bool find_unmatched_output = false;
324 for (uint32_t out_idx = 0; out_idx < num_outputs; out_idx++)
327 nnfw_output_tensorinfo(onert_session, out_idx, &ti);
330 // Check output tensor values
332 const auto &ref_output = interpreter->tensor(interpreter->outputs().at(out_idx))->data;
333 const auto &output = outputs[out_idx];
337 case NNFW_TYPE_TENSOR_BOOL:
338 matched = compareBuffersExactBool(ref_output.uint8, output, out_idx);
340 case NNFW_TYPE_TENSOR_UINT8:
341 case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
342 matched = compareBuffersExact<uint8_t>(ref_output.uint8, output, out_idx);
344 case NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
345 matched = compareBuffersExact<int8_t>(ref_output.int8, output, out_idx);
347 case NNFW_TYPE_TENSOR_INT32:
348 matched = compareBuffersExact<int32_t>(ref_output.i32, output, out_idx);
350 case NNFW_TYPE_TENSOR_FLOAT32:
351 // TODO better way for handling FP error?
352 for (uint32_t e = 0; e < num_elems(&ti); e++)
354 float refval = ref_output.f[e];
355 float val = reinterpret_cast<const float *>(output.data())[e];
356 if (std::abs(refval - val) > max_float_difference)
357 max_float_difference = std::abs(refval - val);
359 if (max_float_difference > DIFFERENCE_THRESHOLD)
363 case NNFW_TYPE_TENSOR_INT64:
364 matched = compareBuffersExact<int64_t>(ref_output.i64, output, out_idx);
367 throw std::runtime_error{"Invalid tensor type"};
371 find_unmatched_output = true;
375 std::cout << "[Comparison] Max float difference: " << max_float_difference << std::endl;
377 if (find_unmatched_output)
379 std::cout << "[Comparison] outputs is not equal!" << std::endl;
380 if (max_float_difference > DIFFERENCE_THRESHOLD)
382 std::cout << "[Comparison] Float outputs is not equal!" << std::endl;
388 std::cout << "[Comparison] Outputs is equal!" << std::endl;
390 std::cout << "[Comparison] Done!" << std::endl;
392 nnfw_close_session(onert_session);