Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / arser / include / arser / arser.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <iostream>
18 #include <sstream>
19
20 #include <iterator>
21 #include <typeinfo>
22
23 #include <algorithm>
24 #include <functional>
25 #include <list>
26 #include <map>
27 #include <string>
28 #include <vector>
29
30 #include <cstring>
31
32 namespace
33 {
34
35 template <typename T> T lexical_cast(const std::string &str)
36 {
37   std::istringstream ss;
38   ss.str(str);
39   T data;
40   ss >> data;
41   return data;
42 }
43
44 template <> bool lexical_cast(const std::string &str)
45 {
46   bool data = true;
47   if (str == "false" || str == "False" || str == "FALSE" || str == "0")
48     data = false;
49   return data;
50 }
51
52 template <typename T> inline std::string to_string(const T value) { return std::to_string(value); }
53
54 template <> inline std::string to_string(const char *value) { return std::string(value); }
55
56 template <> inline std::string to_string(const bool value) { return value ? "true" : "false"; }
57
58 } // namespace
59
60 namespace arser
61 {
62
63 // TypeName declaration
64 template <typename T> struct TypeName
65 {
66   static const char *Get() { return typeid(T).name(); }
67 };
68 template <> struct TypeName<int>
69 {
70   static const char *Get() { return "int"; }
71 };
72 template <> struct TypeName<std::vector<int>>
73 {
74   static const char *Get() { return "vector<int>"; }
75 };
76 template <> struct TypeName<float>
77 {
78   static const char *Get() { return "float"; }
79 };
80 template <> struct TypeName<std::vector<float>>
81 {
82   static const char *Get() { return "vector<float>"; }
83 };
84 template <> struct TypeName<bool>
85 {
86   static const char *Get() { return "bool"; }
87 };
88 template <> struct TypeName<std::string>
89 {
90   static const char *Get() { return "string"; }
91 };
92 template <> struct TypeName<std::vector<std::string>>
93 {
94   static const char *Get() { return "vector<string>"; }
95 };
96 template <> struct TypeName<const char *>
97 {
98   static const char *Get() { return "string"; }
99 };
100 template <> struct TypeName<std::vector<const char *>>
101 {
102   static const char *Get() { return "vector<string>"; }
103 };
104
105 // supported DataType
106 enum class DataType
107 {
108   INT32,
109   INT32_VEC,
110   FLOAT,
111   FLOAT_VEC,
112   BOOL,
113   STR,
114   STR_VEC,
115 };
116
117 class Arser;
118
119 class Argument
120 {
121 public:
122   explicit Argument(const std::string &arg_name) : _name{arg_name} {}
123
124   Argument &nargs(uint32_t num)
125   {
126     if (num == 0)
127     {
128       _type = "bool";
129     }
130     _nargs = num;
131     return *this;
132   }
133
134   Argument &type(DataType type)
135   {
136     switch (type)
137     {
138       case DataType::INT32:
139         _type = "int";
140         break;
141       case DataType::INT32_VEC:
142         _type = "vector<int>";
143         break;
144       case DataType::FLOAT:
145         _type = "float";
146         break;
147       case DataType::FLOAT_VEC:
148         _type = "vector<float>";
149         break;
150       case DataType::BOOL:
151         _type = "bool";
152         break;
153       case DataType::STR:
154         _type = "string";
155         break;
156       case DataType::STR_VEC:
157         _type = "vector<string>";
158         break;
159       default:
160         throw std::runtime_error("NYI DataType");
161     }
162     return *this;
163   }
164
165   Argument &required(void)
166   {
167     _is_required = true;
168     return *this;
169   }
170
171   Argument &required(bool value)
172   {
173     _is_required = value;
174     return *this;
175   }
176
177   Argument &help(std::string help_message)
178   {
179     _help_message = help_message;
180     return *this;
181   }
182
183   Argument &exit_with(const std::function<void(void)> &func)
184   {
185     _func = func;
186     return *this;
187   }
188
189   template <typename T> Argument &default_value(const T value)
190   {
191     if ((_nargs <= 1 && TypeName<T>::Get() == _type) ||
192         (_nargs > 1 && TypeName<std::vector<T>>::Get() == _type))
193       _values.emplace_back(::to_string(value));
194     else
195     {
196       throw std::runtime_error("Type mismatch. "
197                                "You called default_value() method with a type different "
198                                "from the one you specified. "
199                                "Please check the type of what you specified in "
200                                "add_argument() method.");
201     }
202     return *this;
203   }
204
205   template <typename T, typename... Ts> Argument &default_value(const T value, const Ts... values)
206   {
207     if ((_nargs <= 1 && TypeName<T>::Get() == _type) ||
208         (_nargs > 1 && TypeName<std::vector<T>>::Get() == _type))
209     {
210       _values.emplace_back(::to_string(value));
211       default_value(values...);
212     }
213     else
214     {
215       throw std::runtime_error("Type mismatch. "
216                                "You called default_value() method with a type different "
217                                "from the one you specified. "
218                                "Please check the type of what you specified in "
219                                "add_argument() method.");
220     }
221     return *this;
222   }
223
224 private:
225   std::string _name;
226   std::string _type;
227   std::string _help_message;
228   std::function<void(void)> _func;
229   uint32_t _nargs{1};
230   bool _is_required{false};
231   std::vector<std::string> _values;
232
233   friend class Arser;
234   friend std::ostream &operator<<(std::ostream &, const Arser &);
235 };
236
237 class Arser
238 {
239 public:
240   explicit Arser(const std::string &program_description = {})
241       : _program_description{program_description}
242   {
243     add_argument("--help").help("Show help message and exit").nargs(0);
244   }
245
246   Argument &add_argument(const std::string &arg_name)
247   {
248     if (arg_name.at(0) != '-')
249     {
250       _positional_arg_vec.emplace_back(arg_name);
251       _arg_map[arg_name] = &_positional_arg_vec.back();
252     }
253     else
254     {
255       _optional_arg_vec.emplace_back(arg_name);
256       _arg_map[arg_name] = &_optional_arg_vec.back();
257     }
258     return *_arg_map[arg_name];
259   }
260
261   void parse(int argc, char **argv)
262   {
263     _program_name = argv[0];
264     _program_name.erase(0, _program_name.find_last_of("/\\") + 1);
265     if (argc >= 2)
266     {
267       if (!std::strcmp(argv[1], "--help"))
268       {
269         std::cout << *this;
270         std::exit(0);
271       }
272       else
273       {
274         for (const auto &arg : _arg_map)
275         {
276           const auto &func = arg.second->_func;
277           if (func && !std::strcmp(argv[1], arg.second->_name.c_str()))
278           {
279             func();
280             std::exit(0);
281           }
282         }
283       }
284     }
285     /*
286     ** ./program_name [optional argument] [positional argument]
287     */
288     // get the number of positioanl argument
289     size_t parg_num = _positional_arg_vec.size();
290     // get the number of "required" optional argument
291     size_t required_oarg_num = 0;
292     for (auto arg : _optional_arg_vec)
293     {
294       if (arg._is_required)
295         required_oarg_num++;
296     }
297     // parse argument
298     for (int c = 1; c < argc;)
299     {
300       std::string arg_name{argv[c++]};
301       auto arg = _arg_map.find(arg_name);
302       // check whether arg is positional or not
303       if (arg == _arg_map.end())
304       {
305         if (parg_num)
306         {
307           auto it = _positional_arg_vec.begin();
308           std::advance(it, _positional_arg_vec.size() - parg_num);
309           (*it)._values.clear();
310           (*it)._values.emplace_back(arg_name);
311           parg_num--;
312         }
313         else
314           throw std::runtime_error("Invalid argument. "
315                                    "You've given more positional argument than necessary.");
316       }
317       else // optional argument
318       {
319         // check whether arg is required or not
320         if (arg->second->_is_required)
321           required_oarg_num--;
322         arg->second->_values.clear();
323         for (uint32_t n = 0; n < arg->second->_nargs; n++)
324         {
325           if (c >= argc)
326             throw std::runtime_error("Invalid argument. "
327                                      "You must have missed some argument.");
328           arg->second->_values.emplace_back(argv[c++]);
329         }
330         if (arg->second->_nargs == 0)
331         {
332           // TODO std::boolalpha for true or false
333           arg->second->_values.emplace_back("1");
334         }
335       }
336     }
337     if (parg_num || required_oarg_num)
338       throw std::runtime_error("Invalid argument. "
339                                "You must have missed some argument.");
340   }
341
342   bool operator[](const std::string &arg_name)
343   {
344     auto arg = _arg_map.find(arg_name);
345     if (arg == _arg_map.end())
346       return false;
347
348     return arg->second->_values.size() > 0 ? true : false;
349   }
350
351   template <typename T> T get_impl(const std::string &arg_name, T *);
352
353   template <typename T> std::vector<T> get_impl(const std::string &arg_name, std::vector<T> *);
354
355   template <typename T> T get(const std::string &arg_name);
356
357 private:
358   std::string _program_name;
359   std::string _program_description;
360   std::list<Argument> _positional_arg_vec;
361   std::list<Argument> _optional_arg_vec;
362   std::map<std::string, Argument *> _arg_map;
363
364   friend std::ostream &operator<<(std::ostream &, const Arser &);
365 };
366
367 template <typename T> T Arser::get_impl(const std::string &arg_name, T *)
368 {
369   auto arg = _arg_map.find(arg_name);
370   if (arg == _arg_map.end())
371     throw std::runtime_error("Invalid argument. "
372                              "There is no argument you are looking for.");
373
374   if (arg->second->_type != TypeName<T>::Get())
375     throw std::runtime_error("Type mismatch. "
376                              "You called get() method with a type different "
377                              "from the one you specified. "
378                              "Please check the type of what you specified in "
379                              "add_argument() method.");
380
381   if (arg->second->_values.size() == 0)
382     throw std::runtime_error("Wrong access. "
383                              "You must make sure that the argument is given before accessing it. "
384                              "You can do it by calling arser[\"argument\"].");
385
386   return ::lexical_cast<T>(arg->second->_values[0]);
387 }
388
389 template <typename T> std::vector<T> Arser::get_impl(const std::string &arg_name, std::vector<T> *)
390 {
391   auto arg = _arg_map.find(arg_name);
392   if (arg == _arg_map.end())
393     throw std::runtime_error("Invalid argument. "
394                              "There is no argument you are looking for.");
395
396   if (arg->second->_type != TypeName<std::vector<T>>::Get())
397     throw std::runtime_error("Type mismatch. "
398                              "You called get using a type different from the one you specified.");
399
400   std::vector<T> data;
401   std::transform(arg->second->_values.begin(), arg->second->_values.end(), std::back_inserter(data),
402                  [](std::string str) -> T { return ::lexical_cast<T>(str); });
403   return data;
404 }
405
406 template <typename T> T Arser::get(const std::string &arg_name)
407 {
408   return get_impl(arg_name, static_cast<T *>(nullptr));
409 }
410
411 std::ostream &operator<<(std::ostream &stream, const Arser &parser)
412 {
413   // print description
414   if (!parser._program_description.empty())
415   {
416     stream << "What " << parser._program_name << " does: " << parser._program_description << "\n\n";
417   }
418   /*
419   ** print usage
420   */
421   stream << "Usage: ./" << parser._program_name << " ";
422   // required optional argument
423   for (const auto &arg : parser._optional_arg_vec)
424   {
425     if (!arg._is_required)
426       continue;
427     stream << arg._name << " ";
428     std::string arg_name = arg._name.substr(2);
429     std::for_each(arg_name.begin(), arg_name.end(),
430                   [&stream](const char &c) { stream << static_cast<char>(::toupper(c)); });
431     stream << " ";
432   }
433   // rest of the optional argument
434   for (const auto &arg : parser._optional_arg_vec)
435   {
436     if (arg._is_required)
437       continue;
438     stream << "[" << arg._name;
439     if (arg._nargs)
440     {
441       stream << " ";
442       std::string arg_name = arg._name.substr(2);
443       std::for_each(arg_name.begin(), arg_name.end(),
444                     [&stream](const char &c) { stream << static_cast<char>(::toupper(c)); });
445     }
446     stream << "]"
447            << " ";
448   }
449   // positional arguement
450   for (const auto &arg : parser._positional_arg_vec)
451   {
452     stream << arg._name << " ";
453   }
454   stream << "\n\n";
455   /*
456   ** print argument list and its help message
457   */
458   // get the length of the longest argument
459   size_t length_of_longest_arg = 0;
460   for (const auto &arg : parser._positional_arg_vec)
461   {
462     length_of_longest_arg = std::max(length_of_longest_arg, arg._name.length());
463   }
464   for (const auto &arg : parser._optional_arg_vec)
465   {
466     length_of_longest_arg = std::max(length_of_longest_arg, arg._name.length());
467   }
468
469   const size_t message_width = 60;
470   // positional argument
471   if (!parser._positional_arg_vec.empty())
472   {
473     stream << "[Positional argument]" << std::endl;
474     for (const auto &arg : parser._positional_arg_vec)
475     {
476       stream.width(length_of_longest_arg);
477       stream << std::left << arg._name << "\t";
478       for (size_t i = 0; i < arg._help_message.length(); i += message_width)
479       {
480         if (i)
481           stream << std::string(length_of_longest_arg, ' ') << "\t";
482         stream << arg._help_message.substr(i, message_width) << std::endl;
483       }
484     }
485     std::cout << std::endl;
486   }
487   // optional argument
488   if (!parser._optional_arg_vec.empty())
489   {
490     stream << "[Optional argument]" << std::endl;
491     for (const auto &arg : parser._optional_arg_vec)
492     {
493       stream.width(length_of_longest_arg);
494       stream << std::left << arg._name << "\t";
495       for (size_t i = 0; i < arg._help_message.length(); i += message_width)
496       {
497         if (i)
498           stream << std::string(length_of_longest_arg, ' ') << "\t";
499         stream << arg._help_message.substr(i, message_width) << std::endl;
500       }
501     }
502   }
503
504   return stream;
505 }
506
507 } // namespace arser