Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / tests / tools / nnapi_test / src / args.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "args.h"
18
19 #include <iostream>
20
21 namespace nnapi_test
22 {
23
24 Args::Args(const int argc, char **argv)
25 {
26   Initialize();
27   try
28   {
29     Parse(argc, argv);
30   }
31   catch (const std::exception &e)
32   {
33     std::cerr << "The argments that cannot be parsed: " << e.what() << '\n';
34     print(argv);
35     exit(255);
36   }
37 }
38
39 void Args::print(char **argv)
40 {
41   std::cout << "nnapi_test\n\n";
42   std::cout << "Usage: " << argv[0] << " <.tflite> [<options>]\n\n";
43   std::cout << _options;
44   std::cout << "\n";
45 }
46
47 void Args::Initialize(void)
48 {
49   // General options
50   po::options_description general("General options", 100);
51
52   // clang-format off
53   general.add_options()
54     ("help,h", "Print available options")
55     ("tflite", po::value<std::string>()->required())
56     ("seed", po::value<int>()->default_value(0), "The seed of random inputs")
57     ("num_runs", po::value<int>()->default_value(2), "The number of runs")
58     ;
59   // clang-format on
60
61   _options.add(general);
62   _positional.add("tflite", 1);
63   _positional.add("seed", 2);
64 }
65
66 void Args::Parse(const int argc, char **argv)
67 {
68   po::variables_map vm;
69   po::store(po::command_line_parser(argc, argv).options(_options).positional(_positional).run(),
70             vm);
71
72   if (vm.count("help"))
73   {
74     print(argv);
75
76     exit(0);
77   }
78
79   po::notify(vm);
80   if (vm.count("tflite"))
81   {
82     _tflite_filename = vm["tflite"].as<std::string>();
83
84     if (_tflite_filename.empty())
85     {
86       std::cerr << "Please specify tflite file.\n";
87       print(argv);
88       exit(255);
89     }
90     else
91     {
92       if (access(_tflite_filename.c_str(), F_OK) == -1)
93       {
94         std::cerr << "tflite file not found: " << _tflite_filename << "\n";
95         exit(255);
96       }
97     }
98   }
99
100   if (vm.count("seed"))
101   {
102     _seed = vm["seed"].as<int>();
103   }
104
105   if (vm.count("num_runs"))
106   {
107     _num_runs = vm["num_runs"].as<int>();
108     if (_num_runs < 0)
109     {
110       std::cerr << "num_runs value must be greater than 0.\n";
111       exit(255);
112     }
113   }
114 }
115
116 } // end of namespace nnapi_test