Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / tests / tools / nnapi_test / src / nnapi_test.cc
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "tflite/ext/kernels/register.h"
18 #include "tensorflow/lite/model.h"
19
20 #include "tflite/interp/FlatBufferBuilder.h"
21 #include "tflite/RandomTestRunner.h"
22
23 #include <iostream>
24 #include <stdexcept>
25
26 #include "args.h"
27
28 using namespace tflite;
29 using namespace nnfw::tflite;
30 using namespace nnapi_test;
31
32 int main(const int argc, char **argv)
33 {
34   Args args(argc, argv);
35
36   const auto filename = args.getTfliteFilename();
37
38   StderrReporter error_reporter;
39
40   auto model = FlatBufferModel::BuildFromFile(filename.c_str(), &error_reporter);
41
42   if (model == nullptr)
43   {
44     // error_reporter must have shown the error message already
45     return 1;
46   }
47
48   const nnfw::tflite::FlatBufferBuilder builder(*model);
49
50   try
51   {
52     const auto seed = static_cast<uint32_t>(args.getSeed());
53     auto runner = nnfw::tflite::RandomTestRunner::make(seed);
54     const auto num_runs = static_cast<size_t>(args.getNumRuns());
55     runner.compile(builder);
56     return runner.run(num_runs);
57   }
58   catch (const std::exception &e)
59   {
60     std::cerr << e.what() << std::endl;
61     return 1;
62   }
63 }