Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / tests / tools / tflite_vanilla_run / src / args.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __TFLITE_VANILLA_RUN_ARGS_H__
18 #define __TFLITE_VANILLA_RUN_ARGS_H__
19
20 #include <string>
21 #include <boost/program_options.hpp>
22
23 namespace po = boost::program_options;
24
25 namespace TFLiteVanillaRun
26 {
27
28 class Args
29 {
30 public:
31   Args(const int argc, char **argv) noexcept;
32   void print(void);
33
34   const std::string &getTFLiteFilename(void) const { return _tflite_filename; }
35   const std::string &getDumpFilename(void) const { return _dump_filename; }
36   const std::string &getCompareFilename(void) const { return _compare_filename; }
37   const std::string &getInputFilename(void) const { return _input_filename; }
38   const std::vector<int> &getInputShapes(void) const { return _input_shapes; }
39   const int getNumRuns(void) const { return _num_runs; }
40   const int getWarmupRuns(void) const { return _warmup_runs; }
41   const int getRunDelay(void) const { return _run_delay; }
42   const bool getGpuMemoryPoll(void) const { return _gpumem_poll; }
43   const bool getMemoryPoll(void) const { return _mem_poll; }
44   const bool getWriteReport(void) const { return _write_report; }
45   const bool getModelValidate(void) const { return _tflite_validate; }
46   const int getVerboseLevel(void) const { return _verbose_level; }
47
48 private:
49   void Initialize();
50   void Parse(const int argc, char **argv);
51
52 private:
53   po::positional_options_description _positional;
54   po::options_description _options;
55
56   std::string _tflite_filename;
57   std::string _dump_filename;
58   std::string _compare_filename;
59   std::string _input_filename;
60   std::vector<int> _input_shapes;
61   int _num_runs;
62   int _warmup_runs;
63   int _run_delay;
64   bool _gpumem_poll;
65   bool _mem_poll;
66   bool _write_report;
67   bool _tflite_validate;
68   int _verbose_level;
69 };
70
71 } // end of namespace TFLiteVanillaRun
72
73 #endif // __TFLITE_VANILLA_RUN_ARGS_H__