2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __ONERT_TRAIN_ARGS_H__
18 #define __ONERT_TRAIN_ARGS_H__
21 #include <unordered_map>
23 #include <boost/program_options.hpp>
27 namespace po = boost::program_options;
32 using TensorShapeMap = std::unordered_map<uint32_t, TensorShape>;
34 #if defined(ONERT_HAVE_HDF5) && ONERT_HAVE_HDF5 == 1
35 enum class WhenToUseH5Shape
37 NOT_PROVIDED, // Param not provided
38 PREPARE, // read shapes in h5 file and set them as inputs' shape before calling nnfw_prepare()
39 RUN, // read shapes in h5 file and set them as inputs' shape before calling nnfw_run()
46 Args(const int argc, char **argv);
49 const std::string &getPackageFilename(void) const { return _package_filename; }
50 const std::string &getModelFilename(void) const { return _model_filename; }
51 const bool useSingleModel(void) const { return _use_single_model; }
52 const int getDataLength(void) const { return _data_length; }
53 const std::string &getLoadRawInputFilename(void) const { return _load_raw_input_filename; }
54 const std::string &getLoadRawExpectedFilename(void) const { return _load_raw_expected_filename; }
55 const bool getMemoryPoll(void) const { return _mem_poll; }
56 const int getEpoch(void) const { return _epoch; }
57 const int getBatchSize(void) const { return _batch_size; }
58 const float getLearningRate(void) const { return _learning_rate; }
59 const int getLossType(void) const { return _loss_type; }
60 const int getOptimizerType(void) const { return _optimizer_type; }
61 const bool printVersion(void) const { return _print_version; }
62 const int getVerboseLevel(void) const { return _verbose_level; }
63 std::unordered_map<uint32_t, uint32_t> getOutputSizes(void) const { return _output_sizes; }
67 void Parse(const int argc, char **argv);
70 po::positional_options_description _positional;
71 po::options_description _options;
73 std::string _package_filename;
74 std::string _model_filename;
75 bool _use_single_model = false;
77 std::string _load_raw_input_filename;
78 std::string _load_raw_expected_filename;
85 bool _print_version = false;
87 std::unordered_map<uint32_t, uint32_t> _output_sizes;
90 } // end of namespace onert_train
92 #endif // __ONERT_TRAIN_ARGS_H__