Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / tests / tools / tflite_run / src / args.cc
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "args.h"
18
19 #include <iostream>
20
21 namespace TFLiteRun
22 {
23
24 Args::Args(const int argc, char **argv) noexcept
25 {
26   try
27   {
28     Initialize();
29     Parse(argc, argv);
30   }
31   catch (const std::exception &e)
32   {
33     std::cerr << "error during paring args" << e.what() << '\n';
34     exit(1);
35   }
36 }
37
38 void Args::Initialize(void)
39 {
40   auto process_input = [&](const std::string &v) {
41     _input_filename = v;
42
43     if (!_input_filename.empty())
44     {
45       if (access(_input_filename.c_str(), F_OK) == -1)
46       {
47         std::cerr << "input image file not found: " << _input_filename << "\n";
48       }
49     }
50   };
51
52   auto process_tflite = [&](const std::string &v) {
53     _tflite_filename = v;
54
55     if (_tflite_filename.empty())
56     {
57       // TODO Print usage instead of the below message
58       std::cerr << "Please specify tflite file. Run with `--help` for usage."
59                 << "\n";
60
61       exit(1);
62     }
63     else
64     {
65       if (access(_tflite_filename.c_str(), F_OK) == -1)
66       {
67         std::cerr << "tflite file not found: " << _tflite_filename << "\n";
68         exit(1);
69       }
70     }
71   };
72
73   try
74   {
75     // General options
76     po::options_description general("General options");
77
78     // clang-format off
79   general.add_options()
80     ("help,h", "Display available options")
81     ("input,i", po::value<std::string>()->default_value("")->notifier(process_input), "Input filename")
82     ("dump,d", po::value<std::string>()->default_value("")->notifier([&](const auto &v) { _dump_filename = v; }), "Output filename")
83     ("ishapes", po::value<std::vector<int>>()->multitoken()->notifier([&](const auto &v) { _input_shapes = v; }), "Input shapes")
84     ("compare,c", po::value<std::string>()->default_value("")->notifier([&](const auto &v) { _compare_filename = v; }), "filename to be compared with")
85     ("tflite", po::value<std::string>()->required()->notifier(process_tflite))
86     ("num_runs,r", po::value<int>()->default_value(1)->notifier([&](const auto &v) { _num_runs = v; }), "The number of runs")
87     ("warmup_runs,w", po::value<int>()->default_value(0)->notifier([&](const auto &v) { _warmup_runs = v; }), "The number of warmup runs")
88     ("run_delay,t", po::value<int>()->default_value(-1)->notifier([&](const auto &v) { _run_delay = v; }), "Delay time(ms) between runs (as default no delay)")
89     ("gpumem_poll,g", po::value<bool>()->default_value(false)->notifier([&](const auto &v) { _gpumem_poll = v; }), "Check gpu memory polling separately")
90     ("mem_poll,m", po::value<bool>()->default_value(false), "Check memory polling")
91     ("write_report,p", po::value<bool>()->default_value(false)->notifier([&](const auto &v) { _write_report = v; }), "Write report")
92     ("validate", po::value<bool>()->default_value(true)->notifier([&](const auto &v) { _tflite_validate = v; }), "Validate tflite model")
93     ("verbose_level,v", po::value<int>()->default_value(0)->notifier([&](const auto &v) { _verbose_level = v; }), "Verbose level\n"
94          "0: prints the only result. Messages btw run don't print\n"
95          "1: prints result and message btw run\n"
96          "2: prints all of messages to print\n")
97     ;
98     // clang-format on
99
100     _options.add(general);
101     _positional.add("tflite", 1);
102   }
103   catch (const std::bad_cast &e)
104   {
105     std::cerr << "error by bad cast during initialization of boost::program_options" << e.what()
106               << '\n';
107     exit(1);
108   }
109 }
110
111 void Args::Parse(const int argc, char **argv)
112 {
113   po::variables_map vm;
114   po::store(po::command_line_parser(argc, argv).options(_options).positional(_positional).run(),
115             vm);
116
117   {
118     auto conflicting_options = [&](const std::string &o1, const std::string &o2) {
119       if ((vm.count(o1) && !vm[o1].defaulted()) && (vm.count(o2) && !vm[o2].defaulted()))
120       {
121         throw boost::program_options::error(std::string("Two options '") + o1 + "' and '" + o2 +
122                                             "' cannot be given at once.");
123       }
124     };
125
126     conflicting_options("input", "compare");
127   }
128
129   if (vm.count("help"))
130   {
131     std::cout << "tflite_run\n\n";
132     std::cout << "Usage: " << argv[0] << " <.tflite> [<options>]\n\n";
133     std::cout << _options;
134     std::cout << "\n";
135
136     exit(0);
137   }
138
139   po::notify(vm);
140
141   // This must be run after `notify` as `_warm_up_runs` must have been processed before.
142   if (vm.count("mem_poll"))
143   {
144     _mem_poll = vm["mem_poll"].as<bool>();
145     // Instead of EXECUTE to avoid overhead, memory polling runs on WARMUP
146     if (_mem_poll && _warmup_runs == 0)
147     {
148       _warmup_runs = 1;
149     }
150   }
151 }
152
153 } // end of namespace TFLiteRun