Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / tests / tools / tflite_vanilla_run / src / args.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "args.h"
18
19 #include <iostream>
20
21 namespace TFLiteVanillaRun
22 {
23
24 Args::Args(const int argc, char **argv) noexcept
25 {
26   try
27   {
28     Initialize();
29     Parse(argc, argv);
30   }
31   catch (const std::exception &e)
32   {
33     std::cerr << "error during paring args" << e.what() << '\n';
34     exit(1);
35   }
36 }
37
38 void Args::Initialize(void)
39 {
40   try
41   {
42     // General options
43     po::options_description general("General options");
44
45     // clang-format off
46   general.add_options()
47     ("help,h", "Display available options")
48     ("input,i", po::value<std::string>()->default_value(""), "Input filename")
49     ("dump,d", po::value<std::string>()->default_value(""), "Output filename")
50     ("ishapes", po::value<std::vector<int>>()->multitoken(), "Input shapes")
51     ("compare,c", po::value<std::string>()->default_value(""), "filename to be compared with")
52     ("tflite", po::value<std::string>()->required())
53     ("num_runs,r", po::value<int>()->default_value(1), "The number of runs")
54     ("warmup_runs,w", po::value<int>()->default_value(0), "The number of warmup runs")
55     ("run_delay,t", po::value<int>()->default_value(-1), "Delay time(ms) between runs (as default no delay")
56     ("gpumem_poll,g", po::value<bool>()->default_value(false), "Check gpu memory polling separately")
57     ("mem_poll,m", po::value<bool>()->default_value(false), "Check memory polling")
58     ("write_report,p", po::value<bool>()->default_value(false), "Write report")
59     ("validate", po::value<bool>()->default_value(true), "Validate tflite model")
60     ("verbose_level,v", po::value<int>()->default_value(0), "Verbose level\n"
61          "0: prints the only result. Messages btw run don't print\n"
62          "1: prints result and message btw run\n"
63          "2: prints all of messages to print\n")
64     ;
65     // clang-format on
66
67     _options.add(general);
68     _positional.add("tflite", 1);
69   }
70   catch (const std::bad_cast &e)
71   {
72     std::cerr << "error by bad cast during initialization of boost::program_options" << e.what()
73               << '\n';
74     exit(1);
75   }
76 }
77
78 void Args::Parse(const int argc, char **argv)
79 {
80   po::variables_map vm;
81   po::store(po::command_line_parser(argc, argv).options(_options).positional(_positional).run(),
82             vm);
83
84   {
85     auto conflicting_options = [&](const std::string &o1, const std::string &o2) {
86       if ((vm.count(o1) && !vm[o1].defaulted()) && (vm.count(o2) && !vm[o2].defaulted()))
87       {
88         throw boost::program_options::error(std::string("Two options '") + o1 + "' and '" + o2 +
89                                             "' cannot be given at once.");
90       }
91     };
92
93     conflicting_options("input", "compare");
94   }
95
96   if (vm.count("help"))
97   {
98     std::cout << "tflite_run\n\n";
99     std::cout << "Usage: " << argv[0] << " <.tflite> [<options>]\n\n";
100     std::cout << _options;
101     std::cout << "\n";
102
103     exit(0);
104   }
105
106   po::notify(vm);
107
108   if (vm.count("dump"))
109   {
110     _dump_filename = vm["dump"].as<std::string>();
111   }
112
113   if (vm.count("compare"))
114   {
115     _compare_filename = vm["compare"].as<std::string>();
116   }
117
118   if (vm.count("input"))
119   {
120     _input_filename = vm["input"].as<std::string>();
121
122     if (!_input_filename.empty())
123     {
124       if (access(_input_filename.c_str(), F_OK) == -1)
125       {
126         std::cerr << "input image file not found: " << _input_filename << "\n";
127       }
128     }
129   }
130
131   if (vm.count("ishapes"))
132   {
133     _input_shapes.resize(vm["ishapes"].as<std::vector<int>>().size());
134     for (auto i = 0; i < _input_shapes.size(); i++)
135     {
136       _input_shapes[i] = vm["ishapes"].as<std::vector<int>>()[i];
137     }
138   }
139
140   if (vm.count("tflite"))
141   {
142     _tflite_filename = vm["tflite"].as<std::string>();
143
144     if (_tflite_filename.empty())
145     {
146       // TODO Print usage instead of the below message
147       std::cerr << "Please specify tflite file. Run with `--help` for usage."
148                 << "\n";
149
150       exit(1);
151     }
152     else
153     {
154       if (access(_tflite_filename.c_str(), F_OK) == -1)
155       {
156         std::cerr << "tflite file not found: " << _tflite_filename << "\n";
157         exit(1);
158       }
159     }
160   }
161
162   if (vm.count("num_runs"))
163   {
164     _num_runs = vm["num_runs"].as<int>();
165   }
166
167   if (vm.count("warmup_runs"))
168   {
169     _warmup_runs = vm["warmup_runs"].as<int>();
170   }
171
172   if (vm.count("run_delay"))
173   {
174     _run_delay = vm["run_delay"].as<int>();
175   }
176
177   if (vm.count("gpumem_poll"))
178   {
179     _gpumem_poll = vm["gpumem_poll"].as<bool>();
180   }
181
182   if (vm.count("mem_poll"))
183   {
184     _mem_poll = vm["mem_poll"].as<bool>();
185     // Instead of EXECUTE to avoid overhead, memory polling runs on WARMUP
186     if (_mem_poll && _warmup_runs == 0)
187     {
188       _warmup_runs = 1;
189     }
190   }
191
192   if (vm.count("write_report"))
193   {
194     _write_report = vm["write_report"].as<bool>();
195   }
196
197   if (vm.count("validate"))
198   {
199     _tflite_validate = vm["validate"].as<bool>();
200   }
201
202   if (vm.count("verbose_level"))
203   {
204     _verbose_level = vm["verbose_level"].as<int>();
205   }
206 }
207
208 } // end of namespace TFLiteVanillaRun