Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / tests / tools / onert_train / src / args.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "args.h"
18
19 #include <functional>
20 #include <iostream>
21 #include <sys/stat.h>
22 #include <json/json.h>
23
24 namespace
25 {
26
27 // This function parses a json object and returns as a vector of integers
28 // For example,
29 // [0, [1, 2, 3, 4], 3, 40, 4, []] in JSON
30 // is converted to:
31 // {
32 //  0 -> [1, 2, 3, 4]
33 //  3 -> 40
34 //  4 -> []
35 // } in std::unordered_map. Note that the value type is still Json::Value.
36 std::unordered_map<uint32_t, Json::Value> argArrayToMap(const Json::Value &jsonval)
37 {
38   if (!jsonval.isArray() || (jsonval.size() % 2 != 0))
39   {
40     std::cerr << "JSON argument must be an even-sized array in JSON\n";
41     exit(1);
42   }
43
44   std::unordered_map<uint32_t, Json::Value> ret;
45   for (uint32_t i = 0; i < jsonval.size(); i += 2)
46   {
47     if (!jsonval[i].isUInt())
48     {
49       std::cerr << "Key values(values in even indices) must be unsigned integers\n";
50       exit(1);
51     }
52     uint32_t key = jsonval[i].asUInt();
53     Json::Value val = jsonval[i + 1];
54     ret[key] = jsonval[i + 1];
55   }
56   return ret;
57 }
58
59 void checkModelfile(const std::string &model_filename)
60 {
61   if (model_filename.empty())
62   {
63     // TODO Print usage instead of the below message
64     std::cerr << "Please specify model file. Run with `--help` for usage."
65               << "\n";
66
67     exit(1);
68   }
69   else
70   {
71     if (access(model_filename.c_str(), F_OK) == -1)
72     {
73       std::cerr << "Model file not found: " << model_filename << "\n";
74       exit(1);
75     }
76   }
77 }
78
79 void checkPackage(const std::string &package_filename)
80 {
81   if (package_filename.empty())
82   {
83     // TODO Print usage instead of the below message
84     std::cerr << "Please specify nnpackage file. Run with `--help` for usage."
85               << "\n";
86
87     exit(1);
88   }
89   else
90   {
91     if (access(package_filename.c_str(), F_OK) == -1)
92     {
93       std::cerr << "nnpackage not found: " << package_filename << "\n";
94       exit(1);
95     }
96   }
97 }
98
99 } // namespace
100
101 namespace onert_train
102 {
103
104 Args::Args(const int argc, char **argv)
105 {
106   Initialize();
107   Parse(argc, argv);
108 }
109
110 void Args::Initialize(void)
111 {
112   auto process_nnpackage = [&](const std::string &package_filename) {
113     _package_filename = package_filename;
114
115     std::cerr << "Package Filename " << _package_filename << std::endl;
116     checkPackage(package_filename);
117   };
118
119   auto process_modelfile = [&](const std::string &model_filename) {
120     _model_filename = model_filename;
121
122     std::cerr << "Model Filename " << _model_filename << std::endl;
123     checkModelfile(model_filename);
124
125     _use_single_model = true;
126   };
127
128   auto process_path = [&](const std::string &path) {
129     struct stat sb;
130     if (stat(path.c_str(), &sb) == 0)
131     {
132       if (sb.st_mode & S_IFDIR)
133       {
134         _package_filename = path;
135         checkPackage(path);
136         std::cerr << "Package Filename " << path << std::endl;
137       }
138       else
139       {
140         _model_filename = path;
141         checkModelfile(path);
142         std::cerr << "Model Filename " << path << std::endl;
143         _use_single_model = true;
144       }
145     }
146     else
147     {
148       std::cerr << "Cannot find: " << path << "\n";
149       exit(1);
150     }
151   };
152
153   auto process_load_raw_inputfile = [&](const std::string &input_filename) {
154     _load_raw_input_filename = input_filename;
155
156     std::cerr << "Model Input Filename " << _load_raw_input_filename << std::endl;
157     checkModelfile(_load_raw_input_filename);
158   };
159
160   auto process_load_raw_expectedfile = [&](const std::string &expected_filename) {
161     _load_raw_expected_filename = expected_filename;
162
163     std::cerr << "Model Expected Filename " << _load_raw_expected_filename << std::endl;
164     checkModelfile(_load_raw_expected_filename);
165   };
166
167   auto process_output_sizes = [&](const std::string &output_sizes_json_str) {
168     Json::Value root;
169     Json::Reader reader;
170     if (!reader.parse(output_sizes_json_str, root, false))
171     {
172       std::cerr << "Invalid JSON format for output_sizes \"" << output_sizes_json_str << "\"\n";
173       exit(1);
174     }
175
176     auto arg_map = argArrayToMap(root);
177     for (auto &pair : arg_map)
178     {
179       uint32_t key = pair.first;
180       Json::Value &val_json = pair.second;
181       if (!val_json.isUInt())
182       {
183         std::cerr << "All the values in `output_sizes` must be unsigned integers\n";
184         exit(1);
185       }
186       uint32_t val = val_json.asUInt();
187       _output_sizes[key] = val;
188     }
189   };
190
191   // General options
192   po::options_description general("General options", 100);
193
194   // clang-format off
195   general.add_options()
196     ("help,h", "Print available options")
197     ("version", "Print version and exit immediately")
198     ("nnpackage", po::value<std::string>()->notifier(process_nnpackage), "NN Package file(directory) name")
199     ("modelfile", po::value<std::string>()->notifier(process_modelfile), "NN Model filename")
200     ("path", po::value<std::string>()->notifier(process_path), "NN Package or NN Modelfile path")
201     ("data_length", po::value<int>()->default_value(-1)->notifier([&](const auto &v) { _data_length = v; }), "Data length number")
202     ("load_input:raw", po::value<std::string>()->notifier(process_load_raw_inputfile),
203          "NN Model Raw Input data file\n"
204          "The datafile must have data for each input number.\n"
205          "If there are 3 inputs, the data of input0 must exist as much as data_length, "
206          "and the data for input1 and input2 must be held sequentially as data_length.\n"
207     )
208     ("load_expected:raw", po::value<std::string>()->notifier(process_load_raw_expectedfile),
209          "NN Model Raw Expected data file\n"
210          "(Same data policy with load_input:raw)\n"
211     )
212     ("mem_poll,m", po::value<bool>()->default_value(false)->notifier([&](const auto &v) { _mem_poll = v; }), "Check memory polling")
213     ("epoch", po::value<int>()->default_value(5)->notifier([&](const auto &v) { _epoch = v; }), "Epoch number (default: 5)")
214     ("batch_size", po::value<int>()->default_value(32)->notifier([&](const auto &v) { _batch_size = v; }), "Batch size (default: 32)")
215     ("learning_rate", po::value<float>()->default_value(1.0e-4)->notifier([&](const auto &v) { _learning_rate = v; }), "Learning rate (default: 1.0e-4)")
216     ("loss", po::value<int>()->default_value(0)->notifier([&] (const auto &v) { _loss_type = v; }),
217         "Loss type\n"
218         "0: MEAN_SQUARED_ERROR (default)\n"
219         "1: CATEGORICAL_CROSSENTROPY\n")
220     ("optimizer", po::value<int>()->default_value(0)->notifier([&] (const auto &v) { _optimizer_type = v; }),
221       "Optimizer type\n"
222       "0: SGD (default)\n"
223       "1: Adam\n")
224     ("verbose_level,v", po::value<int>()->default_value(0)->notifier([&](const auto &v) { _verbose_level = v; }),
225          "Verbose level\n"
226          "0: prints the only result. Messages btw run don't print\n"
227          "1: prints result and message btw run\n"
228          "2: prints all of messages to print\n")
229     ("output_sizes", po::value<std::string>()->notifier(process_output_sizes),
230         "The output buffer size in JSON 1D array\n"
231         "If not given, the model's output sizes are used\n"
232         "e.g. '[0, 40, 2, 80]' to set 0th tensor to 40 and 2nd tensor to 80.\n")
233     ;
234   // clang-format on
235
236   _options.add(general);
237   _positional.add("path", -1);
238 }
239
240 void Args::Parse(const int argc, char **argv)
241 {
242   po::variables_map vm;
243   po::store(po::command_line_parser(argc, argv).options(_options).positional(_positional).run(),
244             vm);
245
246   if (vm.count("help"))
247   {
248     std::cout << "onert_train\n\n";
249     std::cout << "Usage: " << argv[0] << "[model path] [<options>]\n\n";
250     std::cout << _options;
251     std::cout << "\n";
252
253     exit(0);
254   }
255
256   if (vm.count("version"))
257   {
258     _print_version = true;
259     return;
260   }
261
262   {
263     auto conflicting_options = [&](const std::string &o1, const std::string &o2) {
264       if ((vm.count(o1) && !vm[o1].defaulted()) && (vm.count(o2) && !vm[o2].defaulted()))
265       {
266         throw boost::program_options::error(std::string("Two options '") + o1 + "' and '" + o2 +
267                                             "' cannot be given at once.");
268       }
269     };
270
271     // Cannot use both single model file and nnpackage at once
272     conflicting_options("modelfile", "nnpackage");
273
274     // Require modelfile, nnpackage, or path
275     if (!vm.count("modelfile") && !vm.count("nnpackage") && !vm.count("path"))
276       throw boost::program_options::error(
277         std::string("Require one of options modelfile, nnpackage, or path."));
278   }
279
280   try
281   {
282     po::notify(vm);
283   }
284   catch (const std::bad_cast &e)
285   {
286     std::cerr << "Bad cast error - " << e.what() << '\n';
287     exit(1);
288   }
289 }
290
291 } // end of namespace onert_train