Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / arser / include / arser / arser.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ARSER_H__
18 #define __ARSER_H__
19
20 #include <iostream>
21 #include <sstream>
22
23 #include <iterator>
24 #include <typeinfo>
25
26 #include <algorithm>
27 #include <functional>
28 #include <list>
29 #include <map>
30 #include <string>
31 #include <vector>
32
33 #include <cstring>
34
35 #include <cassert>
36
37 namespace arser
38 {
39 namespace internal
40 {
41
42 template <typename T> T lexical_cast(const std::string &str)
43 {
44   std::istringstream ss;
45   ss.str(str);
46   T data;
47   ss >> data;
48   return data;
49 }
50
51 template <> inline bool lexical_cast(const std::string &str)
52 {
53   bool data = true;
54   if (str == "false" || str == "False" || str == "FALSE" || str == "0")
55     data = false;
56   return data;
57 }
58
59 template <typename T> inline std::string to_string(const T value) { return std::to_string(value); }
60
61 template <> inline std::string to_string(const char *value) { return std::string(value); }
62
63 template <> inline std::string to_string(const bool value) { return value ? "true" : "false"; }
64
65 /**
66  * @brief Returns the string with the leading dash removed.
67  *
68  * If there is no dash, it returns as it is.
69  */
70 inline std::string remove_dash(const std::string &str)
71 {
72   std::string ret{str};
73   auto pos = ret.find_first_not_of('-');
74   if (pos == std::string::npos)
75     return ret;
76   return ret.substr(pos);
77 }
78
79 /**
80  * @brief Returns the string that created by concatenating the elements of a vector with commas.
81  */
82 inline std::string make_comma_concatenated(const std::vector<std::string> &vec)
83 {
84   std::ostringstream oss;
85   std::copy(vec.begin(), std::prev(vec.end()), std::ostream_iterator<std::string>(oss, ", "));
86   oss << vec.back();
87   return oss.str();
88 }
89
90 } // namespace internal
91 } // namespace arser
92
93 namespace arser
94 {
95
96 // TypeName declaration
97 template <typename T> struct TypeName
98 {
99   static const char *Get() { return typeid(T).name(); }
100 };
101 template <> struct TypeName<int>
102 {
103   static const char *Get() { return "int"; }
104 };
105 template <> struct TypeName<std::vector<int>>
106 {
107   static const char *Get() { return "vector<int>"; }
108 };
109 template <> struct TypeName<float>
110 {
111   static const char *Get() { return "float"; }
112 };
113 template <> struct TypeName<std::vector<float>>
114 {
115   static const char *Get() { return "vector<float>"; }
116 };
117 template <> struct TypeName<bool>
118 {
119   static const char *Get() { return "bool"; }
120 };
121 template <> struct TypeName<std::string>
122 {
123   static const char *Get() { return "string"; }
124 };
125 template <> struct TypeName<std::vector<std::string>>
126 {
127   static const char *Get() { return "vector<string>"; }
128 };
129 template <> struct TypeName<const char *>
130 {
131   static const char *Get() { return "string"; }
132 };
133 template <> struct TypeName<std::vector<const char *>>
134 {
135   static const char *Get() { return "vector<string>"; }
136 };
137
138 // supported DataType
139 enum class DataType
140 {
141   INT32,
142   INT32_VEC,
143   FLOAT,
144   FLOAT_VEC,
145   BOOL,
146   STR,
147   STR_VEC,
148 };
149
150 class Arser;
151
152 /**
153  * Argument
154  *   ├── positional argument
155  *   └── optioanl argument  [ dash at the beginning of the string ]
156  *       ├── long option    [ two or more dashes ]
157  *       └── short option   [ one dash ]
158  *
159  * Argument has two types - positional argument, optional argument.
160  *
161  * The way to distinguish the two types is whether there is a dash('-') at the beginning of the
162  * string.
163  *
164  * And, optional argument has two types as well - long option, short option, which is distinguished
165  * by the number of dash.
166  */
167 class Argument
168 {
169 public:
170   explicit Argument(const std::string &arg_name) : _long_name{arg_name}, _names{arg_name} {}
171   explicit Argument(const std::string &short_name, const std::string &long_name)
172     : _short_name{short_name}, _long_name{long_name}, _names{short_name, long_name}
173   {
174   }
175   explicit Argument(const std::string &short_name, const std::string &long_name,
176                     const std::vector<std::string> &names)
177     : _short_name{short_name}, _long_name{long_name}, _names{names}
178   {
179     // 'names' must have 'short_name' and 'long_name'.
180     auto it = std::find(names.begin(), names.end(), short_name);
181     assert(it != names.end());
182     it = std::find(names.begin(), names.end(), long_name);
183     assert(it != names.end());
184     // for avoiding unused warning.
185     (void)it;
186   }
187
188   Argument &nargs(uint32_t num)
189   {
190     if (num == 0)
191     {
192       _type = "bool";
193     }
194     _nargs = num;
195     return *this;
196   }
197
198   Argument &type(DataType type)
199   {
200     switch (type)
201     {
202       case DataType::INT32:
203         _type = "int";
204         break;
205       case DataType::INT32_VEC:
206         _type = "vector<int>";
207         break;
208       case DataType::FLOAT:
209         _type = "float";
210         break;
211       case DataType::FLOAT_VEC:
212         _type = "vector<float>";
213         break;
214       case DataType::BOOL:
215         _type = "bool";
216         break;
217       case DataType::STR:
218         _type = "string";
219         break;
220       case DataType::STR_VEC:
221         _type = "vector<string>";
222         break;
223       default:
224         throw std::runtime_error("NYI DataType");
225     }
226     return *this;
227   }
228
229   Argument &required(void)
230   {
231     _is_required = true;
232     return *this;
233   }
234
235   Argument &required(bool value)
236   {
237     _is_required = value;
238     return *this;
239   }
240
241   Argument &accumulated(void)
242   {
243     _is_accumulated = true;
244     return *this;
245   }
246
247   Argument &accumulated(bool value)
248   {
249     _is_accumulated = value;
250     return *this;
251   }
252
253   Argument &help(std::string help_message)
254   {
255     _help_message = help_message;
256     return *this;
257   }
258
259   Argument &exit_with(const std::function<void(void)> &func)
260   {
261     _func = func;
262     return *this;
263   }
264
265   template <typename T> Argument &default_value(const T value)
266   {
267     if ((_nargs <= 1 && TypeName<T>::Get() == _type) ||
268         (_nargs > 1 && TypeName<std::vector<T>>::Get() == _type))
269       _values.emplace_back(internal::to_string(value));
270     else
271     {
272       throw std::runtime_error("Type mismatch. "
273                                "You called default_value() method with a type different "
274                                "from the one you specified. "
275                                "Please check the type of what you specified in "
276                                "add_argument() method.");
277     }
278     return *this;
279   }
280
281   template <typename T, typename... Ts> Argument &default_value(const T value, const Ts... values)
282   {
283     if ((_nargs <= 1 && TypeName<T>::Get() == _type) ||
284         (_nargs > 1 && TypeName<std::vector<T>>::Get() == _type))
285     {
286       _values.emplace_back(internal::to_string(value));
287       default_value(values...);
288     }
289     else
290     {
291       throw std::runtime_error("Type mismatch. "
292                                "You called default_value() method with a type different "
293                                "from the one you specified. "
294                                "Please check the type of what you specified in "
295                                "add_argument() method.");
296     }
297     return *this;
298   }
299
300 private:
301   // The '_names' vector contains all of the options specified by the user.
302   // And among them, '_long_name' and '_short_name' are selected.
303   std::string _long_name;
304   std::string _short_name;
305   std::vector<std::string> _names;
306   std::string _type;
307   std::string _help_message;
308   std::function<void(void)> _func;
309   uint32_t _nargs{1};
310   bool _is_required{false};
311   bool _is_accumulated{false};
312   std::vector<std::string> _values;
313   std::vector<std::vector<std::string>> _accum_values;
314
315   friend class Arser;
316   friend std::ostream &operator<<(std::ostream &, const Arser &);
317 };
318
319 class Arser
320 {
321 public:
322   explicit Arser(const std::string &program_description = {})
323     : _program_description{program_description}
324   {
325     add_argument("-h", "--help").help("Show help message and exit").nargs(0);
326   }
327
328   Argument &add_argument(const std::string &arg_name)
329   {
330     if (arg_name.at(0) != '-') /* positional */
331     {
332       _positional_arg_vec.emplace_back(arg_name);
333       _arg_map[arg_name] = &_positional_arg_vec.back();
334     }
335     else /* optional */
336     {
337       // The length of optional argument name must be 2 or more.
338       // And it shouldn't be hard to recognize. e.g. '-', '--'
339       if (arg_name.size() < 2)
340       {
341         throw std::runtime_error("Too short name. The length of argument name must be 2 or more.");
342       }
343       if (arg_name == "--")
344       {
345         throw std::runtime_error(
346           "Too short name. Option name must contain at least one character other than dash.");
347       }
348       _optional_arg_vec.emplace_back(arg_name);
349       _optional_arg_vec.back()._short_name = arg_name;
350       _arg_map[arg_name] = &_optional_arg_vec.back();
351     }
352     return *_arg_map[arg_name];
353   }
354
355   Argument &add_argument(const std::vector<std::string> &arg_name_vec)
356   {
357     assert(arg_name_vec.size() >= 2);
358     std::string long_opt, short_opt;
359     // find long and short option
360     for (const auto &arg_name : arg_name_vec)
361     {
362       if (arg_name.at(0) != '-')
363       {
364         throw std::runtime_error("Invalid argument. "
365                                  "Positional argument cannot have short option.");
366       }
367       assert(arg_name.size() >= 2);
368       if (long_opt.empty() && arg_name.at(0) == '-' && arg_name.at(1) == '-')
369       {
370         long_opt = arg_name;
371       }
372       if (short_opt.empty() && arg_name.at(0) == '-' && arg_name.at(1) != '-')
373       {
374         short_opt = arg_name;
375       }
376     }
377     // If one of the two is empty, fill it with the non-empty one for pretty printing.
378     if (long_opt.empty())
379     {
380       assert(not short_opt.empty());
381       long_opt = short_opt;
382     }
383     if (short_opt.empty())
384     {
385       assert(not long_opt.empty());
386       short_opt = long_opt;
387     }
388
389     _optional_arg_vec.emplace_back(short_opt, long_opt, arg_name_vec);
390     for (const auto &arg_name : arg_name_vec)
391     {
392       _arg_map[arg_name] = &_optional_arg_vec.back();
393     }
394     return _optional_arg_vec.back();
395   }
396
397   template <typename... Ts> Argument &add_argument(const std::string &arg_name, Ts... arg_names)
398   {
399     if (sizeof...(arg_names) == 0)
400     {
401       return add_argument(arg_name);
402     }
403     // sizeof...(arg_names) > 0
404     else
405     {
406       return add_argument(std::vector<std::string>{arg_name, arg_names...});
407     }
408   }
409
410   void validate_arguments(void)
411   {
412     // positional argument is always required.
413     for (const auto &arg : _positional_arg_vec)
414     {
415       if (arg._is_required)
416       {
417         throw std::runtime_error("Invalid arguments. Positional argument must always be required.");
418       }
419     }
420     // TODO accumulated arguments shouldn't be enabled to positional arguments.
421     // TODO accumulated arguments shouldn't be enabled to optional arguments whose `narg` == 0.
422   }
423
424   void parse(int argc, char **argv)
425   {
426     validate_arguments();
427     _program_name = argv[0];
428     _program_name.erase(0, _program_name.find_last_of("/\\") + 1);
429     if (argc >= 2)
430     {
431       if (!std::strcmp(argv[1], "--help") || !std::strcmp(argv[1], "-h"))
432       {
433         std::cout << *this;
434         std::exit(0);
435       }
436       else
437       {
438         for (const auto &arg : _arg_map)
439         {
440           const auto &func = arg.second->_func;
441           if (func && !std::strcmp(argv[1], arg.first.c_str()))
442           {
443             func();
444             std::exit(0);
445           }
446         }
447       }
448     }
449     /*
450     ** ./program_name [optional argument] [positional argument]
451     */
452     // get the number of positioanl argument
453     size_t parg_num = _positional_arg_vec.size();
454     // get the number of "required" optional argument
455     size_t required_oarg_num = 0;
456     for (auto arg : _optional_arg_vec)
457     {
458       if (arg._is_required)
459         required_oarg_num++;
460     }
461     // parse argument
462     for (int c = 1; c < argc;)
463     {
464       std::string arg_name{argv[c++]};
465       auto arg = _arg_map.find(arg_name);
466       // check whether arg is positional or not
467       if (arg == _arg_map.end())
468       {
469         if (parg_num)
470         {
471           auto it = _positional_arg_vec.begin();
472           std::advance(it, _positional_arg_vec.size() - parg_num);
473           (*it)._values.clear();
474           (*it)._values.emplace_back(arg_name);
475           parg_num--;
476         }
477         else
478           throw std::runtime_error("Invalid argument. "
479                                    "You've given more positional argument than necessary.");
480       }
481       else // optional argument
482       {
483         // check whether arg is required or not
484         if (arg->second->_is_required)
485           required_oarg_num--;
486         arg->second->_values.clear();
487         for (uint32_t n = 0; n < arg->second->_nargs; n++)
488         {
489           if (c >= argc)
490             throw std::runtime_error("Invalid argument. "
491                                      "You must have missed some argument.");
492           arg->second->_values.emplace_back(argv[c++]);
493         }
494         // accumulate values
495         if (arg->second->_is_accumulated)
496         {
497           arg->second->_accum_values.emplace_back(arg->second->_values);
498         }
499         if (arg->second->_nargs == 0)
500         {
501           // TODO std::boolalpha for true or false
502           arg->second->_values.emplace_back("1");
503         }
504       }
505     }
506     if (parg_num || required_oarg_num)
507       throw std::runtime_error("Invalid argument. "
508                                "You must have missed some argument.");
509   }
510
511   bool operator[](const std::string &arg_name)
512   {
513     auto arg = _arg_map.find(arg_name);
514     if (arg == _arg_map.end())
515       return false;
516
517     if (arg->second->_is_accumulated)
518       return arg->second->_accum_values.size() > 0 ? true : false;
519
520     return arg->second->_values.size() > 0 ? true : false;
521   }
522
523   template <typename T> T get_impl(const std::string &arg_name, T *);
524
525   template <typename T> std::vector<T> get_impl(const std::string &arg_name, std::vector<T> *);
526
527   template <typename T>
528   std::vector<std::vector<T>> get_impl(const std::string &arg_name, std::vector<std::vector<T>> *);
529
530   template <typename T> T get(const std::string &arg_name);
531
532   friend std::ostream &operator<<(std::ostream &stream, const Arser &parser)
533   {
534     // print description
535     if (!parser._program_description.empty())
536     {
537       stream << "What " << parser._program_name << " does: " << parser._program_description
538              << "\n\n";
539     }
540     /*
541     ** print usage
542     */
543     stream << "Usage: ./" << parser._program_name << " ";
544     // required optional argument
545     for (const auto &arg : parser._optional_arg_vec)
546     {
547       if (!arg._is_required)
548         continue;
549       stream << arg._short_name << " ";
550       std::string arg_name = arser::internal::remove_dash(arg._long_name);
551       std::for_each(arg_name.begin(), arg_name.end(),
552                     [&stream](const char &c) { stream << static_cast<char>(::toupper(c)); });
553       stream << " ";
554     }
555     // rest of the optional argument
556     for (const auto &arg : parser._optional_arg_vec)
557     {
558       if (arg._is_required)
559         continue;
560       stream << "[" << arg._short_name;
561       if (arg._nargs)
562       {
563         stream << " ";
564         std::string arg_name = arser::internal::remove_dash(arg._long_name);
565         std::for_each(arg_name.begin(), arg_name.end(),
566                       [&stream](const char &c) { stream << static_cast<char>(::toupper(c)); });
567       }
568       stream << "]"
569              << " ";
570     }
571     // positional arguement
572     for (const auto &arg : parser._positional_arg_vec)
573     {
574       stream << arg._long_name << " ";
575     }
576     stream << "\n\n";
577     /*
578     ** print argument list and its help message
579     */
580     // get the length of the longest argument
581     size_t length_of_longest_arg = 0;
582     for (const auto &arg : parser._positional_arg_vec)
583     {
584       length_of_longest_arg = std::max(length_of_longest_arg,
585                                        arser::internal::make_comma_concatenated(arg._names).size());
586     }
587     for (const auto &arg : parser._optional_arg_vec)
588     {
589       length_of_longest_arg = std::max(length_of_longest_arg,
590                                        arser::internal::make_comma_concatenated(arg._names).size());
591     }
592
593     const size_t message_width = 60;
594     // positional argument
595     if (!parser._positional_arg_vec.empty())
596     {
597       stream << "[Positional argument]" << std::endl;
598       for (const auto &arg : parser._positional_arg_vec)
599       {
600         stream.width(length_of_longest_arg);
601         stream << std::left << arser::internal::make_comma_concatenated(arg._names) << "\t";
602         for (size_t i = 0; i < arg._help_message.length(); i += message_width)
603         {
604           if (i)
605             stream << std::string(length_of_longest_arg, ' ') << "\t";
606           stream << arg._help_message.substr(i, message_width) << std::endl;
607         }
608       }
609       std::cout << std::endl;
610     }
611     // optional argument
612     if (!parser._optional_arg_vec.empty())
613     {
614       stream << "[Optional argument]" << std::endl;
615       for (const auto &arg : parser._optional_arg_vec)
616       {
617         stream.width(length_of_longest_arg);
618         stream << std::left << arser::internal::make_comma_concatenated(arg._names) << "\t";
619         for (size_t i = 0; i < arg._help_message.length(); i += message_width)
620         {
621           if (i)
622             stream << std::string(length_of_longest_arg, ' ') << "\t";
623           stream << arg._help_message.substr(i, message_width) << std::endl;
624         }
625       }
626     }
627
628     return stream;
629   }
630
631 private:
632   std::string _program_name;
633   std::string _program_description;
634   std::list<Argument> _positional_arg_vec;
635   std::list<Argument> _optional_arg_vec;
636   std::map<std::string, Argument *> _arg_map;
637 };
638
639 template <typename T> T Arser::get_impl(const std::string &arg_name, T *)
640 {
641   auto arg = _arg_map.find(arg_name);
642   if (arg == _arg_map.end())
643     throw std::runtime_error("Invalid argument. "
644                              "There is no argument you are looking for: " +
645                              arg_name);
646
647   if (arg->second->_is_accumulated)
648     throw std::runtime_error(
649       "Type mismatch. "
650       "You called get using a type different from the one you specified."
651       "Accumulated argument is returned as std::vector of the specified type");
652
653   if (arg->second->_type != TypeName<T>::Get())
654     throw std::runtime_error("Type mismatch. "
655                              "You called get() method with a type different "
656                              "from the one you specified. "
657                              "Please check the type of what you specified in "
658                              "add_argument() method.");
659
660   if (arg->second->_values.size() == 0)
661     throw std::runtime_error("Wrong access. "
662                              "You must make sure that the argument is given before accessing it. "
663                              "You can do it by calling arser[\"argument\"].");
664
665   return internal::lexical_cast<T>(arg->second->_values[0]);
666 }
667
668 template <typename T> std::vector<T> Arser::get_impl(const std::string &arg_name, std::vector<T> *)
669 {
670   auto arg = _arg_map.find(arg_name);
671   if (arg == _arg_map.end())
672     throw std::runtime_error("Invalid argument. "
673                              "There is no argument you are looking for: " +
674                              arg_name);
675
676   // Accumulated arguments with scalar type (e.g., STR)
677   if (arg->second->_is_accumulated)
678   {
679     if (arg->second->_type != TypeName<T>::Get())
680       throw std::runtime_error("Type mismatch. "
681                                "You called get using a type different from the one you specified.");
682
683     std::vector<T> data;
684     for (auto values : arg->second->_accum_values)
685     {
686       assert(values.size() == 1);
687       data.emplace_back(internal::lexical_cast<T>(values[0]));
688     }
689     return data;
690   }
691
692   if (arg->second->_type != TypeName<std::vector<T>>::Get())
693     throw std::runtime_error("Type mismatch. "
694                              "You called get using a type different from the one you specified.");
695
696   std::vector<T> data;
697   std::transform(arg->second->_values.begin(), arg->second->_values.end(), std::back_inserter(data),
698                  [](std::string str) -> T { return internal::lexical_cast<T>(str); });
699   return data;
700 }
701
702 // Accumulated arguments with vector type (e.g., STR_VEC)
703 template <typename T>
704 std::vector<std::vector<T>> Arser::get_impl(const std::string &arg_name,
705                                             std::vector<std::vector<T>> *)
706 {
707   auto arg = _arg_map.find(arg_name);
708   if (arg == _arg_map.end())
709     throw std::runtime_error("Invalid argument. "
710                              "There is no argument you are looking for: " +
711                              arg_name);
712
713   if (not arg->second->_is_accumulated)
714     throw std::runtime_error("Type mismatch. "
715                              "You called get using a type different from the one you specified.");
716
717   if (arg->second->_type != TypeName<std::vector<T>>::Get())
718     throw std::runtime_error(
719       "Type mismatch. "
720       "You called get using a type different from the one you specified."
721       "Accumulated argument is returned as std::vector of the specified type");
722
723   std::vector<std::vector<T>> result;
724   for (auto values : arg->second->_accum_values)
725   {
726     std::vector<T> data;
727     std::transform(values.begin(), values.end(), std::back_inserter(data),
728                    [](std::string str) -> T { return internal::lexical_cast<T>(str); });
729     result.emplace_back(data);
730   }
731
732   return result;
733 }
734
735 template <typename T> T Arser::get(const std::string &arg_name)
736 {
737   return get_impl(arg_name, static_cast<T *>(nullptr));
738 }
739
740 } // namespace arser
741
742 #endif // __ARSER_H__