Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / tests / tools / onert_train / src / args.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_TRAIN_ARGS_H__
18 #define __ONERT_TRAIN_ARGS_H__
19
20 #include <string>
21 #include <unordered_map>
22 #include <vector>
23 #include <boost/program_options.hpp>
24
25 #include "types.h"
26
27 namespace po = boost::program_options;
28
29 namespace onert_train
30 {
31
32 using TensorShapeMap = std::unordered_map<uint32_t, TensorShape>;
33
34 #if defined(ONERT_HAVE_HDF5) && ONERT_HAVE_HDF5 == 1
35 enum class WhenToUseH5Shape
36 {
37   NOT_PROVIDED, // Param not provided
38   PREPARE, // read shapes in h5 file and set them as inputs' shape before calling nnfw_prepare()
39   RUN,     // read shapes in h5 file and set them as inputs' shape before calling nnfw_run()
40 };
41 #endif
42
43 class Args
44 {
45 public:
46   Args(const int argc, char **argv);
47   void print(void);
48
49   const std::string &getPackageFilename(void) const { return _package_filename; }
50   const std::string &getModelFilename(void) const { return _model_filename; }
51   const bool useSingleModel(void) const { return _use_single_model; }
52   const int getDataLength(void) const { return _data_length; }
53   const std::string &getLoadRawInputFilename(void) const { return _load_raw_input_filename; }
54   const std::string &getLoadRawExpectedFilename(void) const { return _load_raw_expected_filename; }
55   const bool getMemoryPoll(void) const { return _mem_poll; }
56   const int getEpoch(void) const { return _epoch; }
57   const int getBatchSize(void) const { return _batch_size; }
58   const float getLearningRate(void) const { return _learning_rate; }
59   const int getLossType(void) const { return _loss_type; }
60   const int getOptimizerType(void) const { return _optimizer_type; }
61   const bool printVersion(void) const { return _print_version; }
62   const int getVerboseLevel(void) const { return _verbose_level; }
63   std::unordered_map<uint32_t, uint32_t> getOutputSizes(void) const { return _output_sizes; }
64
65 private:
66   void Initialize();
67   void Parse(const int argc, char **argv);
68
69 private:
70   po::positional_options_description _positional;
71   po::options_description _options;
72
73   std::string _package_filename;
74   std::string _model_filename;
75   bool _use_single_model = false;
76   int _data_length;
77   std::string _load_raw_input_filename;
78   std::string _load_raw_expected_filename;
79   bool _mem_poll;
80   int _epoch;
81   int _batch_size;
82   float _learning_rate;
83   int _loss_type;
84   int _optimizer_type;
85   bool _print_version = false;
86   int _verbose_level;
87   std::unordered_map<uint32_t, uint32_t> _output_sizes;
88 };
89
90 } // end of namespace onert_train
91
92 #endif // __ONERT_TRAIN_ARGS_H__