2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 namespace TFLiteVanillaRun
// Constructor taking the raw command line (argc/argv). It is declared noexcept and
// contains a handler for std::exception that reports the failure on std::cerr.
// NOTE(review): the guarded statements and whatever follows the error output are not
// visible in this view; confirm against the full file before relying on this summary.
24 Args::Args(const int argc, char **argv) noexcept
31 catch (const std::exception &e)
// Writes a fixed message immediately followed by e.what() and a newline.
// NOTE(review): "paring" is presumably a typo for "parsing", and there is no separator
// between the fixed message and e.what(), so the two run together in the output.
33 std::cerr << "error during paring args" << e.what() << '\n';
// Declares the command-line options understood by this tool. The visible lines build
// a po::options_description titled "General options", add it to _options, and register
// "tflite" in _positional. A std::bad_cast handler reports on std::cerr.
// NOTE(review): several lines of this function are not visible in this view (the
// statement that opens the option list, the braces, and the tail of the handler).
38 void Args::Initialize(void)
43 po::options_description general("General options");
// Each entry below is "long_name,short_letter", an optional typed value, and help text.
47 ("help,h", "Display available options")
// "input", "dump" and "compare" are string options that default to an empty string.
48 ("input,i", po::value<std::string>()->default_value(""), "Input filename")
49 ("dump,d", po::value<std::string>()->default_value(""), "Output filename")
// "ishapes" accepts several integer tokens after the flag (multitoken) and has no default.
50 ("ishapes", po::value<std::vector<int>>()->multitoken(), "Input shapes")
51 ("compare,c", po::value<std::string>()->default_value(""), "filename to be compared with")
// NOTE(review): "tflite" is marked required() and is the only entry without a help
// string; it is also registered as a positional argument further down.
52 ("tflite", po::value<std::string>()->required())
// Integer options controlling the number of runs and the delay between them.
53 ("num_runs,r", po::value<int>()->default_value(1), "The number of runs")
54 ("warmup_runs,w", po::value<int>()->default_value(0), "The number of warmup runs")
// NOTE(review): the help text below has an unclosed "(" ("(as default no delay").
55 ("run_delay,t", po::value<int>()->default_value(-1), "Delay time(ms) between runs (as default no delay")
// Boolean options; all default to false except "validate", which defaults to true.
56 ("gpumem_poll,g", po::value<bool>()->default_value(false), "Check gpu memory polling separately")
57 ("mem_poll,m", po::value<bool>()->default_value(false), "Check memory polling")
58 ("write_report,p", po::value<bool>()->default_value(false), "Write report")
59 ("validate", po::value<bool>()->default_value(true), "Validate tflite model")
// The help text for "verbose_level" is one string literal split across four lines.
60 ("verbose_level,v", po::value<int>()->default_value(0), "Verbose level\n"
61 "0: prints the only result. Messages btw run don't print\n"
62 "1: prints result and message btw run\n"
63 "2: prints all of messages to print\n")
// Publish the option group and map positional input to "tflite" (second argument 1 is
// the max count in Boost.Program_options' positional_options_description::add).
67 _options.add(general);
68 _positional.add("tflite", 1);
70 catch (const std::bad_cast &e)
// NOTE(review): no separator is written between the fixed message and e.what(), and
// the rest of this output statement is not visible in this view.
72 std::cerr << "error by bad cast during initialization of boost::program_options" << e.what()
// Parses argv against _options and _positional, then copies each parsed value out of
// the variables map `vm` into the matching member (_dump_filename, _input_filename,
// _num_runs, ...). Some values are checked as they are read (file existence, empty
// model path).
// NOTE(review): parts of this function are not visible in this view, including the
// declaration of `vm`, the condition guarding the usage output, and the statements
// that follow each error message. Confirm against the full file.
78 void Args::Parse(const int argc, char **argv)
// Run the command-line parser with both the named and the positional descriptions.
81 po::store(po::command_line_parser(argc, argv).options(_options).positional(_positional).run(),
// Helper: throws a program_options error when both named options are present in `vm`
// and neither of them merely holds its default value.
85 auto conflicting_options = [&](const std::string &o1, const std::string &o2) {
86 if ((vm.count(o1) && !vm[o1].defaulted()) && (vm.count(o2) && !vm[o2].defaulted()))
88 throw boost::program_options::error(std::string("Two options '") + o1 + "' and '" + o2 +
89 "' cannot be given at once.");
// "input" and "compare" may not both be given explicitly.
93 conflicting_options("input", "compare");
// Usage text: tool name, invocation pattern using argv[0], then the option list.
98 std::cout << "tflite_run\n\n";
99 std::cout << "Usage: " << argv[0] << " <.tflite> [<options>]\n\n";
100 std::cout << _options;
// NOTE(review): "dump", "compare" and "input" are declared with default_value("") in
// Initialize, so these count() checks presumably always succeed; verify.
108 if (vm.count("dump"))
110 _dump_filename = vm["dump"].as<std::string>();
113 if (vm.count("compare"))
115 _compare_filename = vm["compare"].as<std::string>();
118 if (vm.count("input"))
120 _input_filename = vm["input"].as<std::string>();
// A non-empty input path is checked for existence with access(..., F_OK).
122 if (!_input_filename.empty())
124 if (access(_input_filename.c_str(), F_OK) == -1)
// NOTE(review): the message says "image file" although the option is described only as
// "Input filename"; what happens after this message is not visible here.
126 std::cerr << "input image file not found: " << _input_filename << "\n";
// Copy the integers given to "ishapes" into _input_shapes element by element.
131 if (vm.count("ishapes"))
133 _input_shapes.resize(vm["ishapes"].as<std::vector<int>>().size());
// NOTE(review): `auto i = 0` deduces int, which is then compared with the unsigned
// result of size(); the as<>() lookup is also repeated on every iteration.
134 for (auto i = 0; i < _input_shapes.size(); i++)
136 _input_shapes[i] = vm["ishapes"].as<std::vector<int>>()[i];
// Model path: must be non-empty and must exist on disk.
140 if (vm.count("tflite"))
142 _tflite_filename = vm["tflite"].as<std::string>();
144 if (_tflite_filename.empty())
146 // TODO Print usage instead of the below message
147 std::cerr << "Please specify tflite file. Run with `--help` for usage."
154 if (access(_tflite_filename.c_str(), F_OK) == -1)
156 std::cerr << "tflite file not found: " << _tflite_filename << "\n";
// Plain copies of the numeric options into their members.
162 if (vm.count("num_runs"))
164 _num_runs = vm["num_runs"].as<int>();
167 if (vm.count("warmup_runs"))
169 _warmup_runs = vm["warmup_runs"].as<int>();
172 if (vm.count("run_delay"))
174 _run_delay = vm["run_delay"].as<int>();
177 if (vm.count("gpumem_poll"))
179 _gpumem_poll = vm["gpumem_poll"].as<bool>();
182 if (vm.count("mem_poll"))
184 _mem_poll = vm["mem_poll"].as<bool>();
// The existing comment below says memory polling happens in the WARMUP phase rather
// than EXECUTE to avoid overhead. The check that follows detects memory polling being
// requested with zero warmup runs; the action taken in that case is not visible here.
185 // Instead of EXECUTE to avoid overhead, memory polling runs on WARMUP
186 if (_mem_poll && _warmup_runs == 0)
192 if (vm.count("write_report"))
194 _write_report = vm["write_report"].as<bool>();
197 if (vm.count("validate"))
199 _tflite_validate = vm["validate"].as<bool>();
202 if (vm.count("verbose_level"))
204 _verbose_level = vm["verbose_level"].as<int>();
208 } // end of namespace TFLiteVanillaRun