// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
 *
 * @file   task_runner.cpp
 * @brief  task runner for the simpleshot demonstration
 * @see    https://github.com/nnstreamer/nntrainer
 * @author Jihoon Lee <jhoon.it.lee@samsung.com>
 * @bug    No known bugs except for NYI items
 */
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <unistd.h>

#include <app_context.h>
#include <nntrainer-api-common.h>

#include "layers/centering.h"
26 namespace simpleshot {
31 * @brief get backbone path from a model name
33 * @param model resnet50 or conv4 is supported
36 const std::string getModelFilePath(const std::string &model,
37 const std::string &app_path) {
38 const std::string resnet_model_path =
39 app_path + "/backbones/resnet50_60classes.tflite";
40 const std::string conv4_model_path =
41 app_path + "/backbones/conv4_60classes.tflite";
43 std::string model_path;
45 if (model == "resnet50") {
46 model_path = resnet_model_path;
47 } else if (model == "conv4") {
48 model_path = conv4_model_path;
51 std::ifstream infile(model_path);
54 ss << model_path << " as backbone does not exist!";
55 throw std::invalid_argument(ss.str().c_str());
58 if (model_path.empty()) {
60 ss << "not supported model type given, model type: " << model;
61 throw std::invalid_argument(ss.str().c_str());
67 const std::string getFeatureFilePath(const std::string &model,
68 const std::string &app_path) {
69 const std::string resnet_model_path =
70 app_path + "/backbones/resnet50_60classes_feature_vector.bin";
71 const std::string conv4_model_path =
72 app_path + "/backbones/conv4_60classes_feature_vector.bin";
74 std::string model_path;
76 if (model == "resnet50") {
77 model_path = resnet_model_path;
78 } else if (model == "conv4") {
79 model_path = conv4_model_path;
82 std::ifstream infile(model_path);
85 ss << model_path << " as backbone does not exist!";
86 throw std::invalid_argument(ss.str().c_str());
89 if (model_path.empty()) {
91 ss << "not supported model type given, model type: " << model;
92 throw std::invalid_argument(ss.str().c_str());
99 * @brief get current working directory by cpp string
101 * @return const std::string current working directory
103 const std::string getcwd_() {
104 const size_t bufsize = 4096;
105 char buffer[bufsize];
107 return getcwd(buffer, bufsize);
111 using LayerHandle = std::shared_ptr<ml::train::Layer>;
114 * @brief Create a Model with given backbone and varient setup
116 * @param backbone either conv4 or resnet50, hardcoded tflite path will be
118 * @param app_path designated app path to search the backbone file
119 * @param variant "one of UN, L2N, CL2N"
120 * @return std::unique_ptr<ml::train::Model>
122 std::unique_ptr<ml::train::Model> createModel(const std::string &backbone,
123 const std::string &app_path,
124 const std::string &variant = "UN",
125 const int num_classes = 5) {
126 auto model = ml::train::createModel(ml::train::ModelType::NEURAL_NET,
127 {"batch_size=1", "epochs=1"});
129 LayerHandle backbone_layer = ml::train::layer::BackboneTFLite(
130 {"name=backbone", "model_path=" + getModelFilePath(backbone, app_path),
131 "input_shape=32:32:3", "trainable=false"});
132 model->addLayer(backbone_layer);
134 auto generate_knn_part = [&backbone, &app_path,
135 num_classes](const std::string &variant_) {
136 std::vector<LayerHandle> v;
138 const std::string num_class_prop =
139 "num_class=" + std::to_string(num_classes);
141 if (variant_ == "UN") {
142 /// left empty intended
143 } else if (variant_ == "L2N") {
144 LayerHandle l2 = ml::train::createLayer(
145 "preprocess_l2norm", {"name=l2norm", "trainable=false"});
147 } else if (variant_ == "CL2N") {
148 LayerHandle centering = ml::train::createLayer(
149 "centering", {"name=center",
150 "feature_path=" + getFeatureFilePath(backbone, app_path),
152 LayerHandle l2 = ml::train::createLayer(
153 "preprocess_l2norm", {"name=l2norm", "trainable=false"});
154 v.push_back(centering);
157 std::stringstream ss;
158 ss << "unsupported variant type: " << variant_;
159 throw std::invalid_argument(ss.str().c_str());
162 LayerHandle knn = ml::train::createLayer(
163 "centroid_knn", {"name=knn", num_class_prop, "trainable=false"});
169 auto knn_part = generate_knn_part(variant);
170 for (auto &layer : knn_part) {
171 model->addLayer(layer);
176 } // namespace simpleshot
183 int main(int argc, char **argv) {
184 auto &app_context = nntrainer::AppContext::Global();
186 if (argc != 6 && argc != 5) {
188 << "usage: model method train_file validation_file app_path\n"
189 << "model: are [resnet50, conv4]\n"
190 << "methods: are [UN, L2N, CL2N]\n"
191 << "train file: [app_path]/tasks/[train_file] is used for training\n"
192 << "validation file: [app_path]/tasks/[validation_file] is used for "
194 << "app_path: root path to refer to resources, if not given"
195 "path is set current working directory\n";
199 for (int i = 0; i < argc; ++i) {
200 if (argv[i] == nullptr) {
202 << "usage: model method train_file_path validation_file_path app_path\n"
203 << "Supported model types are [resnet50, conv4]\n"
204 << "Supported methods are [UN, L2N, CL2N]\n"
205 << "train file: [app_path]/tasks/[train_file] is used for training\n"
206 << "validation file: [app_path]/tasks/[validation_file] is used for "
208 << "app_path: root path to refer to resources, if not given"
209 "path is set current working directory\n";
214 std::string model_str(argv[1]);
215 std::string app_path =
216 argc == 6 ? std::string(argv[5]) : simpleshot::getcwd_();
217 std::string method = argv[2];
218 std::string train_path = app_path + "/tasks/" + argv[3];
219 std::string val_path = app_path + "/tasks/" + argv[4];
222 app_context.registerFactory(
223 nntrainer::createLayer<simpleshot::layers::CenteringLayer>);
224 } catch (std::exception &e) {
225 std::cerr << "registering factory failed: " << e.what();
229 std::unique_ptr<ml::train::Model> model;
231 model = simpleshot::createModel(model_str, app_path, method);
232 model->summarize(std::cout, ML_TRAIN_SUMMARY_MODEL);
233 } catch (std::exception &e) {
234 std::cerr << "creating Model failed: " << e.what();
238 std::shared_ptr<ml::train::Dataset> train_dataset, valid_dataset;
240 train_dataset = ml::train::createDataset(ml::train::DatasetType::FILE,
243 ml::train::createDataset(ml::train::DatasetType::FILE, val_path.c_str());
246 std::cerr << "creating dataset failed";
250 if (model->setDataset(ml::train::DatasetModeType::MODE_TRAIN,
252 std::cerr << "failed to set train dataset" << std::endl;
256 if (model->setDataset(ml::train::DatasetModeType::MODE_VALID,
258 std::cerr << "failed to set valid dataset" << std::endl;
262 std::shared_ptr<ml::train::Optimizer> optimizer;
264 optimizer = ml::train::optimizer::SGD({"learning_rate=0.1"});
266 std::cerr << "creating optimizer failed";
270 if (model->setOptimizer(optimizer) != 0) {
271 std::cerr << "failed to set optimizer" << std::endl;
275 if (model->compile() != 0) {
276 std::cerr << "model compilation failed" << std::endl;
280 if (model->initialize() != 0) {
281 std::cerr << "model initiation failed" << std::endl;
285 if (model->train() != 0) {
286 std::cerr << "train failed" << std::endl;
290 std::cout << "successfully ran" << std::endl;