2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "allocation.h"
19 #include "benchmark.h"
22 #include "nnfw_util.h"
23 #include "nnfw_internal.h"
24 #include "nnfw_experimental.h"
25 #include "randomgen.h"
26 #include "rawformatter.h"
27 #include "rawdataloader.h"
29 #include <boost/program_options.hpp>
36 #include <unordered_map>
// Backend name passed to nnfw_set_available_backends() to select the training
// backend. Declared constexpr so the pointer itself is immutable as well as the
// characters it points to (it is only ever read).
static constexpr const char *default_backend_cand = "train";
41 int main(const int argc, char **argv)
43 using namespace onert_train;
47 Args args(argc, argv);
48 if (args.printVersion())
51 NNPR_ENSURE_STATUS(nnfw_query_info_u32(NULL, NNFW_INFO_ID_VERSION, &version));
52 std::cout << "onert_train (nnfw runtime: v" << (version >> 24) << "."
53 << ((version & 0x0000FF00) >> 8) << "." << (version & 0xFF) << ")" << std::endl;
57 // TODO Apply verbose level to phases
58 const int verbose = args.getVerboseLevel();
59 benchmark::Phases phases(benchmark::PhaseOption{});
61 nnfw_session *session = nullptr;
62 NNPR_ENSURE_STATUS(nnfw_create_session(&session));
65 phases.run("MODEL_LOAD", [&](const benchmark::Phase &, uint32_t) {
66 if (args.useSingleModel())
68 nnfw_load_model_from_modelfile(session, args.getModelFilename().c_str()));
70 NNPR_ENSURE_STATUS(nnfw_load_model_from_file(session, args.getPackageFilename().c_str()));
73 // Set training backend
74 NNPR_ENSURE_STATUS(nnfw_set_available_backends(session, default_backend_cand));
77 NNPR_ENSURE_STATUS(nnfw_input_size(session, &num_inputs));
79 uint32_t num_expecteds;
80 NNPR_ENSURE_STATUS(nnfw_output_size(session, &num_expecteds));
82 // verify input and output
84 auto verifyInputTypes = [session]() {
86 NNPR_ENSURE_STATUS(nnfw_input_size(session, &sz));
87 for (uint32_t i = 0; i < sz; ++i)
90 NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session, i, &ti));
92 if (ti.dtype < NNFW_TYPE_TENSOR_FLOAT32 || ti.dtype > NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED)
94 std::cerr << "E: not supported input type" << std::endl;
100 auto verifyOutputTypes = [session]() {
102 NNPR_ENSURE_STATUS(nnfw_output_size(session, &sz));
104 for (uint32_t i = 0; i < sz; ++i)
107 NNPR_ENSURE_STATUS(nnfw_output_tensorinfo(session, i, &ti));
109 if (ti.dtype < NNFW_TYPE_TENSOR_FLOAT32 || ti.dtype > NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED)
111 std::cerr << "E: not supported output type" << std::endl;
120 auto convertLossType = [](int type) {
124 return NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR;
126 return NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY;
128 std::cerr << "E: not supported loss type" << std::endl;
133 auto convertOptType = [](int type) {
137 return NNFW_TRAIN_OPTIMIZER_SGD;
139 return NNFW_TRAIN_OPTIMIZER_ADAM;
141 std::cerr << "E: not supported optimizer type" << std::endl;
146 // prepare training info
148 tri.batch_size = args.getBatchSize();
149 tri.learning_rate = args.getLearningRate();
150 tri.loss = convertLossType(args.getLossType());
151 tri.opt = convertOptType(args.getOptimizerType());
155 // TODO When nnfw_{prepare|run} are failed, can't catch the time
156 phases.run("PREPARE", [&](const benchmark::Phase &, uint32_t) {
157 NNPR_ENSURE_STATUS(nnfw_train_prepare(session, &tri));
160 // prepare input and expected tensor info lists
161 std::vector<nnfw_tensorinfo> input_infos;
162 std::vector<nnfw_tensorinfo> expected_infos;
164 // prepare data buffers
165 std::vector<Allocation> input_data(num_inputs);
166 std::vector<Allocation> expected_data(num_expecteds);
168 for (uint32_t i = 0; i < num_inputs; ++i)
171 NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session, i, &ti));
172 input_data[i].alloc(bufsize_for(&ti));
173 input_infos.emplace_back(std::move(ti));
176 for (uint32_t i = 0; i < num_expecteds; ++i)
179 NNPR_ENSURE_STATUS(nnfw_output_tensorinfo(session, i, &ti));
180 expected_data[i].alloc(bufsize_for(&ti));
181 expected_infos.emplace_back(std::move(ti));
184 auto data_length = args.getDataLength();
187 RawDataLoader rawDataLoader;
189 if (!args.getLoadRawInputFilename().empty() && !args.getLoadRawExpectedFilename().empty())
192 rawDataLoader.loadData(args.getLoadRawInputFilename(), args.getLoadRawExpectedFilename(),
193 input_infos, expected_infos, data_length, tri.batch_size);
197 // TODO Use random generator
198 std::cerr << "E: not supported random input and expected generator" << std::endl;
203 std::vector<float> losses(num_expecteds);
204 phases.run("EXECUTE", [&](const benchmark::Phase &, uint32_t) {
205 const int num_step = data_length / tri.batch_size;
206 const int num_epoch = args.getEpoch();
207 measure.set(num_epoch, num_step);
208 for (uint32_t epoch = 0; epoch < num_epoch; ++epoch)
210 std::fill(losses.begin(), losses.end(), 0);
211 for (uint32_t n = 0; n < num_step; ++n)
213 // get batchsize data
214 if (!generator(n, input_data, expected_data))
218 for (uint32_t i = 0; i < num_inputs; ++i)
221 nnfw_train_set_input(session, i, input_data[i].data(), &input_infos[i]));
225 for (uint32_t i = 0; i < num_expecteds; ++i)
228 nnfw_train_set_expected(session, i, expected_data[i].data(), &expected_infos[i]));
232 measure.run(epoch, n, [&]() { NNPR_ENSURE_STATUS(nnfw_train(session, true)); });
235 for (int32_t i = 0; i < num_expecteds; ++i)
238 NNPR_ENSURE_STATUS(nnfw_train_get_loss(session, i, &temp));
244 std::cout << std::fixed;
245 std::cout.precision(3);
246 std::cout << "Epoch " << epoch + 1 << "/" << num_epoch << " - " << measure.timeMs(epoch)
247 << "ms/step - loss: ";
248 std::cout.precision(4);
249 for (uint32_t i = 0; i < num_expecteds; ++i)
251 std::cout << "[" << i << "] " << losses[i] / num_step;
253 std::cout /* << "- accuracy: " << accuracy*/ << std::endl;
257 NNPR_ENSURE_STATUS(nnfw_close_session(session));
260 benchmark::Result result(phases);
263 benchmark::printResult(result);
267 catch (boost::program_options::error &e)
269 std::cerr << "E: " << e.what() << std::endl;
272 catch (std::runtime_error &e)
274 std::cerr << "E: Fail to run by runtime error:" << e.what() << std::endl;