2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
22 #include <json/json.h>
27 // This function parses a JSON array of alternating key/value entries and returns it as a map
29 // [0, [1, 2, 3, 4], 3, 40, 4, []] in JSON
35 // } in std::unordered_map. Note that the value type is still Json::Value.
// Builds a key -> value map from a flat JSON array laid out as [key0, value0, key1, value1, ...].
// jsonval: must be a JSON array with an even number of elements; otherwise an error message is
//          written to stderr.
// The map `ret` uses the uint32_t keys found at even indices; values are stored as Json::Value.
// NOTE(review): some lines of this function are not visible in this view (braces, whatever
// follows each error message, and the final return — presumably `ret`); confirm against the
// full file.
36 std::unordered_map<uint32_t, Json::Value> argArrayToMap(const Json::Value &jsonval)
// Reject non-array input and arrays that cannot be split into (key, value) pairs.
38 if (!jsonval.isArray() || (jsonval.size() % 2 != 0))
40 std::cerr << "JSON argument must be an even-sized array in JSON\n";
44 std::unordered_map<uint32_t, Json::Value> ret;
// Step through the array two elements at a time: index i holds the key, i + 1 the value.
45 for (uint32_t i = 0; i < jsonval.size(); i += 2)
// Keys are required to be unsigned integers.
47 if (!jsonval[i].isUInt())
49 std::cerr << "Key values(values in even indices) must be unsigned integers\n";
52 uint32_t key = jsonval[i].asUInt();
// NOTE(review): `val` is a copy of jsonval[i + 1], but the assignment below reads jsonval[i + 1]
// again instead of using it, so `val` looks unused in the visible code.
53 Json::Value val = jsonval[i + 1];
// Assignment through operator[] means a key that appears twice keeps its last value.
54 ret[key] = jsonval[i + 1];
// Validates a model file path and reports problems on stderr.
// model_filename: path to check; an empty string and a path for which access() fails are
//                 both reported as errors.
// NOTE(review): the statements that follow each error message are not visible in this view —
// presumably the process terminates there; confirm against the full file.
59 void checkModelfile(const std::string &model_filename)
// No path was supplied at all.
61 if (model_filename.empty())
63 // TODO Print usage instead of the below message
64 std::cerr << "Please specify model file. Run with `--help` for usage."
// F_OK asks access() only whether the path exists; read permission is not checked here.
71 if (access(model_filename.c_str(), F_OK) == -1)
73 std::cerr << "Model file not found: " << model_filename << "\n";
// Validates an nnpackage path and reports problems on stderr.
// package_filename: path to check; an empty string and a path for which access() fails are
//                   both reported as errors.
// NOTE(review): the statements that follow each error message are not visible in this view —
// presumably the process terminates there; confirm against the full file.
79 void checkPackage(const std::string &package_filename)
// No path was supplied at all.
81 if (package_filename.empty())
83 // TODO Print usage instead of the below message
84 std::cerr << "Please specify nnpackage file. Run with `--help` for usage."
// F_OK asks access() only whether the path exists; it does not verify that it is a directory.
91 if (access(package_filename.c_str(), F_OK) == -1)
93 std::cerr << "nnpackage not found: " << package_filename << "\n";
101 namespace onert_train
// Constructs the argument holder from the process's command-line arguments.
// argc / argv: argument count and vector.
// NOTE(review): the constructor body is not visible in this view; presumably it calls
// Initialize() and then Parse(argc, argv) — confirm against the full file.
104 Args::Args(const int argc, char **argv)
// Declares every supported command-line option and wires each one to a notifier that stores
// the parsed value in the corresponding member. The option set is then registered in _options,
// and positional arguments are mapped onto the "path" option.
// NOTE(review): several lines of this function are not visible in this view (braces, some
// local declarations, some option description strings); comments below describe only the
// visible statements.
110 void Args::Initialize(void)
// --nnpackage: remember the package path, echo it to stderr and validate it with checkPackage.
112 auto process_nnpackage = [&](const std::string &package_filename) {
113 _package_filename = package_filename;
115 std::cerr << "Package Filename " << _package_filename << std::endl;
116 checkPackage(package_filename);
// --modelfile: remember the model path, echo it, validate it, and mark single-model mode.
119 auto process_modelfile = [&](const std::string &model_filename) {
120 _model_filename = model_filename;
122 std::cerr << "Model Filename " << _model_filename << std::endl;
123 checkModelfile(model_filename);
125 _use_single_model = true;
// --path (also the positional argument): use stat() to decide whether the path is a package
// directory or a single model file.
// NOTE(review): the declaration of `sb` is not visible here — presumably `struct stat sb;`.
128 auto process_path = [&](const std::string &path) {
130 if (stat(path.c_str(), &sb) == 0)
// NOTE(review): this is a raw bit test. On platforms where S_IFBLK shares a bit with S_IFDIR
// (e.g. Linux: 0060000 vs 0040000) a block device would also take this branch;
// S_ISDIR(sb.st_mode) is the conventional check.
132 if (sb.st_mode & S_IFDIR)
134 _package_filename = path;
136 std::cerr << "Package Filename " << path << std::endl;
// Anything that is not a directory is treated as a single model file.
140 _model_filename = path;
141 checkModelfile(path);
142 std::cerr << "Model Filename " << path << std::endl;
143 _use_single_model = true;
// stat() failed, so the path could not be inspected.
148 std::cerr << "Cannot find: " << path << "\n";
// --load_input:raw: remember the raw input data file and echo it.
// The existence check reuses checkModelfile, so its error messages mention "model file".
153 auto process_load_raw_inputfile = [&](const std::string &input_filename) {
154 _load_raw_input_filename = input_filename;
156 std::cerr << "Model Input Filename " << _load_raw_input_filename << std::endl;
157 checkModelfile(_load_raw_input_filename);
// --load_expected:raw: same handling as the raw input file, for the expected-output data file.
160 auto process_load_raw_expectedfile = [&](const std::string &expected_filename) {
161 _load_raw_expected_filename = expected_filename;
163 std::cerr << "Model Expected Filename " << _load_raw_expected_filename << std::endl;
164 checkModelfile(_load_raw_expected_filename);
// --output_sizes: parse the JSON string, turn it into an index -> value map with argArrayToMap,
// and store each value (which must be an unsigned integer) in _output_sizes.
// NOTE(review): the declarations of `root` and `reader` are not visible here — presumably a
// Json::Value and a Json::Reader.
167 auto process_output_sizes = [&](const std::string &output_sizes_json_str) {
170 if (!reader.parse(output_sizes_json_str, root, false))
172 std::cerr << "Invalid JSON format for output_sizes \"" << output_sizes_json_str << "\"\n";
176 auto arg_map = argArrayToMap(root);
177 for (auto &pair : arg_map)
179 uint32_t key = pair.first;
180 Json::Value &val_json = pair.second;
181 if (!val_json.isUInt())
183 std::cerr << "All the values in `output_sizes` must be unsigned integers\n";
186 uint32_t val = val_json.asUInt();
187 _output_sizes[key] = val;
192 po::options_description general("General options", 100);
// Option table. Each notifier runs with the final value of its option; options declared with
// default_value() therefore also populate their member when not given on the command line.
195 general.add_options()
196 ("help,h", "Print available options")
197 ("version", "Print version and exit immediately")
198 ("nnpackage", po::value<std::string>()->notifier(process_nnpackage), "NN Package file(directory) name")
199 ("modelfile", po::value<std::string>()->notifier(process_modelfile), "NN Model filename")
200 ("path", po::value<std::string>()->notifier(process_path), "NN Package or NN Modelfile path")
201 ("data_length", po::value<int>()->default_value(-1)->notifier([&](const auto &v) { _data_length = v; }), "Data length number")
202 ("load_input:raw", po::value<std::string>()->notifier(process_load_raw_inputfile),
203 "NN Model Raw Input data file\n"
204 "The datafile must have data for each input number.\n"
205 "If there are 3 inputs, the data of input0 must exist as much as data_length, "
206 "and the data for input1 and input2 must be held sequentially as data_length.\n"
208 ("load_expected:raw", po::value<std::string>()->notifier(process_load_raw_expectedfile),
209 "NN Model Raw Expected data file\n"
210 "(Same data policy with load_input:raw)\n"
212 ("mem_poll,m", po::value<bool>()->default_value(false)->notifier([&](const auto &v) { _mem_poll = v; }), "Check memory polling")
213 ("epoch", po::value<int>()->default_value(5)->notifier([&](const auto &v) { _epoch = v; }), "Epoch number (default: 5)")
214 ("batch_size", po::value<int>()->default_value(32)->notifier([&](const auto &v) { _batch_size = v; }), "Batch size (default: 32)")
215 ("learning_rate", po::value<float>()->default_value(1.0e-4)->notifier([&](const auto &v) { _learning_rate = v; }), "Learning rate (default: 1.0e-4)")
216 ("loss", po::value<int>()->default_value(0)->notifier([&] (const auto &v) { _loss_type = v; }),
218 "0: MEAN_SQUARED_ERROR (default)\n"
219 "1: CATEGORICAL_CROSSENTROPY\n")
220 ("optimizer", po::value<int>()->default_value(0)->notifier([&] (const auto &v) { _optimizer_type = v; }),
224 ("verbose_level,v", po::value<int>()->default_value(0)->notifier([&](const auto &v) { _verbose_level = v; }),
226 "0: prints the only result. Messages btw run don't print\n"
227 "1: prints result and message btw run\n"
228 "2: prints all of messages to print\n")
229 ("output_sizes", po::value<std::string>()->notifier(process_output_sizes),
230 "The output buffer size in JSON 1D array\n"
231 "If not given, the model's output sizes are used\n"
232 "e.g. '[0, 40, 2, 80]' to set 0th tensor to 40 and 2nd tensor to 80.\n")
236 _options.add(general);
// Route positional arguments to the "path" option (a max_count of -1 is Boost's "unlimited").
237 _positional.add("path", -1);
// Parses the command line against the options registered in _options/_positional, handles
// --help and --version, and enforces the constraints between the model-selection options.
// argc / argv: argument count and vector, as passed to the constructor.
// NOTE(review): parts of this function are not visible in this view (the end of the po::store
// call, the opening of the try block, any notify step, and what follows each branch);
// comments below describe only the visible statements.
240 void Args::Parse(const int argc, char **argv)
242 po::variables_map vm;
243 po::store(po::command_line_parser(argc, argv).options(_options).positional(_positional).run(),
// --help: print a short usage banner followed by the generated option descriptions.
246 if (vm.count("help"))
248 std::cout << "onert_train\n\n";
// NOTE(review): no space is emitted between argv[0] and "[model path]", so the printed usage
// line runs the program name straight into the bracket.
249 std::cout << "Usage: " << argv[0] << "[model path] [<options>]\n\n";
250 std::cout << _options;
// --version: only record the request here.
256 if (vm.count("version"))
258 _print_version = true;
// Helper that throws when both options were explicitly given; an option counts as given only
// if it is present in vm and was not merely filled in from its default value.
263 auto conflicting_options = [&](const std::string &o1, const std::string &o2) {
264 if ((vm.count(o1) && !vm[o1].defaulted()) && (vm.count(o2) && !vm[o2].defaulted()))
266 throw boost::program_options::error(std::string("Two options '") + o1 + "' and '" + o2 +
267 "' cannot be given at once.");
271 // Cannot use both single model file and nnpackage at once
272 conflicting_options("modelfile", "nnpackage");
274 // Require modelfile, nnpackage, or path
275 if (!vm.count("modelfile") && !vm.count("nnpackage") && !vm.count("path"))
276 throw boost::program_options::error(
277 std::string("Require one of options modelfile, nnpackage, or path."));
// std::bad_cast is reported on stderr.
// NOTE(review): the matching try block is not visible in this view.
284 catch (const std::bad_cast &e)
286 std::cerr << "Bad cast error - " << e.what() << '\n';
291 } // end of namespace onert_train