Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / tests / tools / onert_train / src / onert_train.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
#include "allocation.h"
#include "args.h"
#include "benchmark.h"
#include "measure.h"
#include "nnfw.h"
#include "nnfw_util.h"
#include "nnfw_internal.h"
#include "nnfw_experimental.h"
#include "randomgen.h"
#include "rawformatter.h"
#include "rawdataloader.h"

#include <boost/program_options.hpp>
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdlib>
#include <iostream>
#include <libgen.h>
#include <stdexcept>
#include <unordered_map>
#include <vector>
38
// Backend candidate handed to nnfw_set_available_backends(); training runs on the "train" backend.
static const char *default_backend_cand = "train";
40
41 int main(const int argc, char **argv)
42 {
43   using namespace onert_train;
44
45   try
46   {
47     Args args(argc, argv);
48     if (args.printVersion())
49     {
50       uint32_t version;
51       NNPR_ENSURE_STATUS(nnfw_query_info_u32(NULL, NNFW_INFO_ID_VERSION, &version));
52       std::cout << "onert_train (nnfw runtime: v" << (version >> 24) << "."
53                 << ((version & 0x0000FF00) >> 8) << "." << (version & 0xFF) << ")" << std::endl;
54       exit(0);
55     }
56
57     // TODO Apply verbose level to phases
58     const int verbose = args.getVerboseLevel();
59     benchmark::Phases phases(benchmark::PhaseOption{});
60
61     nnfw_session *session = nullptr;
62     NNPR_ENSURE_STATUS(nnfw_create_session(&session));
63
64     // ModelLoad
65     phases.run("MODEL_LOAD", [&](const benchmark::Phase &, uint32_t) {
66       if (args.useSingleModel())
67         NNPR_ENSURE_STATUS(
68           nnfw_load_model_from_modelfile(session, args.getModelFilename().c_str()));
69       else
70         NNPR_ENSURE_STATUS(nnfw_load_model_from_file(session, args.getPackageFilename().c_str()));
71     });
72
73     // Set training backend
74     NNPR_ENSURE_STATUS(nnfw_set_available_backends(session, default_backend_cand));
75
76     uint32_t num_inputs;
77     NNPR_ENSURE_STATUS(nnfw_input_size(session, &num_inputs));
78
79     uint32_t num_expecteds;
80     NNPR_ENSURE_STATUS(nnfw_output_size(session, &num_expecteds));
81
82     // verify input and output
83
84     auto verifyInputTypes = [session]() {
85       uint32_t sz;
86       NNPR_ENSURE_STATUS(nnfw_input_size(session, &sz));
87       for (uint32_t i = 0; i < sz; ++i)
88       {
89         nnfw_tensorinfo ti;
90         NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session, i, &ti));
91
92         if (ti.dtype < NNFW_TYPE_TENSOR_FLOAT32 || ti.dtype > NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED)
93         {
94           std::cerr << "E: not supported input type" << std::endl;
95           exit(-1);
96         }
97       }
98     };
99
100     auto verifyOutputTypes = [session]() {
101       uint32_t sz;
102       NNPR_ENSURE_STATUS(nnfw_output_size(session, &sz));
103
104       for (uint32_t i = 0; i < sz; ++i)
105       {
106         nnfw_tensorinfo ti;
107         NNPR_ENSURE_STATUS(nnfw_output_tensorinfo(session, i, &ti));
108
109         if (ti.dtype < NNFW_TYPE_TENSOR_FLOAT32 || ti.dtype > NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED)
110         {
111           std::cerr << "E: not supported output type" << std::endl;
112           exit(-1);
113         }
114       }
115     };
116
117     verifyInputTypes();
118     verifyOutputTypes();
119
120     auto convertLossType = [](int type) {
121       switch (type)
122       {
123         case 0:
124           return NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR;
125         case 1:
126           return NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY;
127         default:
128           std::cerr << "E: not supported loss type" << std::endl;
129           exit(-1);
130       }
131     };
132
133     auto convertOptType = [](int type) {
134       switch (type)
135       {
136         case 0:
137           return NNFW_TRAIN_OPTIMIZER_SGD;
138         case 1:
139           return NNFW_TRAIN_OPTIMIZER_ADAM;
140         default:
141           std::cerr << "E: not supported optimizer type" << std::endl;
142           exit(-1);
143       }
144     };
145
146     // prepare training info
147     nnfw_train_info tri;
148     tri.batch_size = args.getBatchSize();
149     tri.learning_rate = args.getLearningRate();
150     tri.loss = convertLossType(args.getLossType());
151     tri.opt = convertOptType(args.getOptimizerType());
152
153     // prepare execution
154
155     // TODO When nnfw_{prepare|run} are failed, can't catch the time
156     phases.run("PREPARE", [&](const benchmark::Phase &, uint32_t) {
157       NNPR_ENSURE_STATUS(nnfw_train_prepare(session, &tri));
158     });
159
160     // prepare input and expected tensor info lists
161     std::vector<nnfw_tensorinfo> input_infos;
162     std::vector<nnfw_tensorinfo> expected_infos;
163
164     // prepare data buffers
165     std::vector<Allocation> input_data(num_inputs);
166     std::vector<Allocation> expected_data(num_expecteds);
167
168     for (uint32_t i = 0; i < num_inputs; ++i)
169     {
170       nnfw_tensorinfo ti;
171       NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session, i, &ti));
172       input_data[i].alloc(bufsize_for(&ti));
173       input_infos.emplace_back(std::move(ti));
174     }
175
176     for (uint32_t i = 0; i < num_expecteds; ++i)
177     {
178       nnfw_tensorinfo ti;
179       NNPR_ENSURE_STATUS(nnfw_output_tensorinfo(session, i, &ti));
180       expected_data[i].alloc(bufsize_for(&ti));
181       expected_infos.emplace_back(std::move(ti));
182     }
183
184     auto data_length = args.getDataLength();
185
186     Generator generator;
187     RawDataLoader rawDataLoader;
188
189     if (!args.getLoadRawInputFilename().empty() && !args.getLoadRawExpectedFilename().empty())
190     {
191       generator =
192         rawDataLoader.loadData(args.getLoadRawInputFilename(), args.getLoadRawExpectedFilename(),
193                                input_infos, expected_infos, data_length, tri.batch_size);
194     }
195     else
196     {
197       // TODO Use random generator
198       std::cerr << "E: not supported random input and expected generator" << std::endl;
199       exit(-1);
200     }
201
202     Measure measure;
203     std::vector<float> losses(num_expecteds);
204     phases.run("EXECUTE", [&](const benchmark::Phase &, uint32_t) {
205       const int num_step = data_length / tri.batch_size;
206       const int num_epoch = args.getEpoch();
207       measure.set(num_epoch, num_step);
208       for (uint32_t epoch = 0; epoch < num_epoch; ++epoch)
209       {
210         std::fill(losses.begin(), losses.end(), 0);
211         for (uint32_t n = 0; n < num_step; ++n)
212         {
213           // get batchsize data
214           if (!generator(n, input_data, expected_data))
215             break;
216
217           // prepare input
218           for (uint32_t i = 0; i < num_inputs; ++i)
219           {
220             NNPR_ENSURE_STATUS(
221               nnfw_train_set_input(session, i, input_data[i].data(), &input_infos[i]));
222           }
223
224           // prepare output
225           for (uint32_t i = 0; i < num_expecteds; ++i)
226           {
227             NNPR_ENSURE_STATUS(
228               nnfw_train_set_expected(session, i, expected_data[i].data(), &expected_infos[i]));
229           }
230
231           // train
232           measure.run(epoch, n, [&]() { NNPR_ENSURE_STATUS(nnfw_train(session, true)); });
233
234           // store loss
235           for (int32_t i = 0; i < num_expecteds; ++i)
236           {
237             float temp = 0.f;
238             NNPR_ENSURE_STATUS(nnfw_train_get_loss(session, i, &temp));
239             losses[i] += temp;
240           }
241         }
242
243         // print loss
244         std::cout << std::fixed;
245         std::cout.precision(3);
246         std::cout << "Epoch " << epoch + 1 << "/" << num_epoch << " - " << measure.timeMs(epoch)
247                   << "ms/step - loss: ";
248         std::cout.precision(4);
249         for (uint32_t i = 0; i < num_expecteds; ++i)
250         {
251           std::cout << "[" << i << "] " << losses[i] / num_step;
252         }
253         std::cout /* << "- accuracy: " << accuracy*/ << std::endl;
254       }
255     });
256
257     NNPR_ENSURE_STATUS(nnfw_close_session(session));
258
259     // prepare result
260     benchmark::Result result(phases);
261
262     // to stdout
263     benchmark::printResult(result);
264
265     return 0;
266   }
267   catch (boost::program_options::error &e)
268   {
269     std::cerr << "E: " << e.what() << std::endl;
270     exit(-1);
271   }
272   catch (std::runtime_error &e)
273   {
274     std::cerr << "E: Fail to run by runtime error:" << e.what() << std::endl;
275     exit(-1);
276   }
277 }