f58ba7c53f0941b67bfcdd3de7c63f935badc5b3
[platform/core/ml/nntrainer.git] / Applications / SimpleShot / task_runner.cpp
1 // SPDX-License-Identifier: Apache-2.0
2 /**
3  * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
4  *
5  * @file   task_runner.cpp
6  * @date   08 Jan 2021
7  * @brief  task runner for the simpleshot demonstration
8  * @see    https://github.com/nnstreamer/nntrainer
9  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10  * @bug    No known bugs except for NYI items
11  */
12
13 #include <fstream>
14 #include <iostream>
15 #include <memory>
16 #include <sstream>
17 #include <string>
18 #include <unistd.h>
19
20 #include <app_context.h>
21 #include <model.h>
22 #include <nntrainer-api-common.h>
23
24 #include "layers/centering.h"
25
26 namespace simpleshot {
27
28 namespace {
29
30 /**
31  * @brief get backbone path from a model name
32  *
33  * @param model resnet50 or conv4  is supported
34  *
35  */
36 const std::string getModelFilePath(const std::string &model,
37                                    const std::string &app_path) {
38   const std::string resnet_model_path =
39     app_path + "/backbones/resnet50_60classes.tflite";
40   const std::string conv4_model_path =
41     app_path + "/backbones/conv4_60classes.tflite";
42
43   std::string model_path;
44
45   if (model == "resnet50") {
46     model_path = resnet_model_path;
47   } else if (model == "conv4") {
48     model_path = conv4_model_path;
49   }
50
51   std::ifstream infile(model_path);
52   if (!infile.good()) {
53     std::stringstream ss;
54     ss << model_path << " as backbone does not exist!";
55     throw std::invalid_argument(ss.str().c_str());
56   }
57
58   if (model_path.empty()) {
59     std::stringstream ss;
60     ss << "not supported model type given, model type: " << model;
61     throw std::invalid_argument(ss.str().c_str());
62   }
63
64   return model_path;
65 }
66
67 const std::string getFeatureFilePath(const std::string &model,
68                                      const std::string &app_path) {
69   const std::string resnet_model_path =
70     app_path + "/backbones/resnet50_60classes_feature_vector.bin";
71   const std::string conv4_model_path =
72     app_path + "/backbones/conv4_60classes_feature_vector.bin";
73
74   std::string model_path;
75
76   if (model == "resnet50") {
77     model_path = resnet_model_path;
78   } else if (model == "conv4") {
79     model_path = conv4_model_path;
80   }
81
82   std::ifstream infile(model_path);
83   if (!infile.good()) {
84     std::stringstream ss;
85     ss << model_path << " as backbone does not exist!";
86     throw std::invalid_argument(ss.str().c_str());
87   }
88
89   if (model_path.empty()) {
90     std::stringstream ss;
91     ss << "not supported model type given, model type: " << model;
92     throw std::invalid_argument(ss.str().c_str());
93   }
94
95   return model_path;
96 }
97
98 /**
99  * @brief get current working directory by cpp string
100  *
101  * @return const std::string current working directory
102  */
103 const std::string getcwd_() {
104   const size_t bufsize = 4096;
105   char buffer[bufsize];
106
107   return getcwd(buffer, bufsize);
108 }
109 } // namespace
110
111 using LayerHandle = std::shared_ptr<ml::train::Layer>;
112
113 /**
114  * @brief Create a Model with given backbone and varient setup
115  *
116  * @param backbone either conv4 or resnet50, hardcoded tflite path will be
117  * selected
118  * @param app_path designated app path to search the backbone file
119  * @param variant "one of UN, L2N, CL2N"
120  * @return std::unique_ptr<ml::train::Model>
121  */
122 std::unique_ptr<ml::train::Model> createModel(const std::string &backbone,
123                                               const std::string &app_path,
124                                               const std::string &variant = "UN",
125                                               const int num_classes = 5) {
126   auto model = ml::train::createModel(ml::train::ModelType::NEURAL_NET,
127                                       {"batch_size=1", "epochs=1"});
128
129   LayerHandle backbone_layer = ml::train::layer::BackboneTFLite(
130     {"name=backbone", "model_path=" + getModelFilePath(backbone, app_path),
131      "input_shape=32:32:3", "trainable=false"});
132   model->addLayer(backbone_layer);
133
134   auto generate_knn_part = [&backbone, &app_path,
135                             num_classes](const std::string &variant_) {
136     std::vector<LayerHandle> v;
137
138     const std::string num_class_prop =
139       "num_class=" + std::to_string(num_classes);
140
141     if (variant_ == "UN") {
142       /// left empty intended
143     } else if (variant_ == "L2N") {
144       LayerHandle l2 = ml::train::createLayer(
145         "preprocess_l2norm", {"name=l2norm", "trainable=false"});
146       v.push_back(l2);
147     } else if (variant_ == "CL2N") {
148       LayerHandle centering = ml::train::createLayer(
149         "centering", {"name=center",
150                       "feature_path=" + getFeatureFilePath(backbone, app_path),
151                       "trainable=false"});
152       LayerHandle l2 = ml::train::createLayer(
153         "preprocess_l2norm", {"name=l2norm", "trainable=false"});
154       v.push_back(centering);
155       v.push_back(l2);
156     } else {
157       std::stringstream ss;
158       ss << "unsupported variant type: " << variant_;
159       throw std::invalid_argument(ss.str().c_str());
160     }
161
162     LayerHandle knn = ml::train::createLayer(
163       "centroid_knn", {"name=knn", num_class_prop, "trainable=false"});
164     v.push_back(knn);
165
166     return v;
167   };
168
169   auto knn_part = generate_knn_part(variant);
170   for (auto &layer : knn_part) {
171     model->addLayer(layer);
172   }
173
174   return model;
175 }
176 } // namespace simpleshot
177
178 /**
179  * @brief main runner
180  *
181  * @return int
182  */
183 int main(int argc, char **argv) {
184   auto &app_context = nntrainer::AppContext::Global();
185
186   if (argc != 6 && argc != 5) {
187     std::cout
188       << "usage: model method train_file validation_file app_path\n"
189       << "model: are [resnet50, conv4]\n"
190       << "methods: are [UN, L2N, CL2N]\n"
191       << "train file: [app_path]/tasks/[train_file] is used for training\n"
192       << "validation file: [app_path]/tasks/[validation_file] is used for "
193          "validation\n"
194       << "app_path: root path to refer to resources, if not given"
195          "path is set current working directory\n";
196     return 1;
197   }
198
199   for (int i = 0; i < argc; ++i) {
200     if (argv[i] == nullptr) {
201       std::cout
202         << "usage: model method train_file_path validation_file_path app_path\n"
203         << "Supported model types are [resnet50, conv4]\n"
204         << "Supported methods are [UN, L2N, CL2N]\n"
205         << "train file: [app_path]/tasks/[train_file] is used for training\n"
206         << "validation file: [app_path]/tasks/[validation_file] is used for "
207            "validation\n"
208         << "app_path: root path to refer to resources, if not given"
209            "path is set current working directory\n";
210       return 1;
211     }
212   }
213
214   std::string model_str(argv[1]);
215   std::string app_path =
216     argc == 6 ? std::string(argv[5]) : simpleshot::getcwd_();
217   std::string method = argv[2];
218   std::string train_path = app_path + "/tasks/" + argv[3];
219   std::string val_path = app_path + "/tasks/" + argv[4];
220
221   try {
222     app_context.registerFactory(
223       nntrainer::createLayer<simpleshot::layers::CenteringLayer>);
224   } catch (std::exception &e) {
225     std::cerr << "registering factory failed: " << e.what();
226     return 1;
227   }
228
229   std::unique_ptr<ml::train::Model> model;
230   try {
231     model = simpleshot::createModel(model_str, app_path, method);
232     model->summarize(std::cout, ML_TRAIN_SUMMARY_MODEL);
233   } catch (std::exception &e) {
234     std::cerr << "creating Model failed: " << e.what();
235     return 1;
236   }
237
238   std::shared_ptr<ml::train::Dataset> train_dataset, valid_dataset;
239   try {
240     train_dataset = ml::train::createDataset(ml::train::DatasetType::FILE,
241                                              train_path.c_str());
242     valid_dataset =
243       ml::train::createDataset(ml::train::DatasetType::FILE, val_path.c_str());
244
245   } catch (...) {
246     std::cerr << "creating dataset failed";
247     return 1;
248   }
249
250   if (model->setDataset(ml::train::DatasetModeType::MODE_TRAIN,
251                         train_dataset)) {
252     std::cerr << "failed to set train dataset" << std::endl;
253     return 1;
254   };
255
256   if (model->setDataset(ml::train::DatasetModeType::MODE_VALID,
257                         valid_dataset)) {
258     std::cerr << "failed to set valid dataset" << std::endl;
259     return 1;
260   };
261
262   std::shared_ptr<ml::train::Optimizer> optimizer;
263   try {
264     optimizer = ml::train::optimizer::SGD({"learning_rate=0.1"});
265   } catch (...) {
266     std::cerr << "creating optimizer failed";
267     return 1;
268   }
269
270   if (model->setOptimizer(optimizer) != 0) {
271     std::cerr << "failed to set optimizer" << std::endl;
272     return 1;
273   }
274
275   if (model->compile() != 0) {
276     std::cerr << "model compilation failed" << std::endl;
277     return 1;
278   }
279
280   if (model->initialize() != 0) {
281     std::cerr << "model initiation failed" << std::endl;
282     return 1;
283   }
284
285   if (model->train() != 0) {
286     std::cerr << "train failed" << std::endl;
287     return 1;
288   }
289
290   std::cout << "successfully ran" << std::endl;
291   return 0;
292 }