1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
6 * @brief The entry point for Inference Engine validation application
7 * @file validation_app/main.cpp
9 #include <gflags/gflags.h>
23 #include <ext_list.hpp>
25 #include <samples/common.hpp>
26 #include <samples/slog.hpp>
28 #include "user_exception.hpp"
29 #include "ClassificationProcessor.hpp"
30 #include "SSDObjectDetectionProcessor.hpp"
31 #include "YOLOObjectDetectionProcessor.hpp"
34 using namespace InferenceEngine;
36 using InferenceEngine::details::InferenceEngineException;
/// @brief Help text printed next to the -h option
static constexpr char help_message[] = "Print a help message";
40 /// @brief Message for images argument
41 static const char image_message[] = "Required. Folder with validation images. Path to a directory with validation images. For Classification models,"
42 " the directory must contain folders named as labels with images inside or a .txt file with"
43 " a list of images. For Object Detection models, the dataset must be in"
45 /// @brief Message for plugin_path argument
46 static const char plugin_path_message[] = "Required. Path to an .xml file with a trained model, including model name and "
/// @brief Help text for -m (path to the model .xml file)
static constexpr char model_message[] =
    "Required. Path to an .xml file with a trained model";
/// @brief Help text for -p (explicit plugin name)
static constexpr char plugin_message[] =
    "Plugin name. For example, CPU. If this parameter is passed, the sample looks for a specified plugin only.";
/// @brief Help text for -d (target device)
static constexpr char target_device_message[] =
    "Target device to infer on: CPU (default), GPU, FPGA, HDDL or MYRIAD. "
    "The application looks for a suitable plugin for the specified device.";
/// @brief Help text for -OCl (file with labels for a model)
static constexpr char label_message[] =
    "Path to a file with labels for a model";
/// @brief Help text for -b (batch size)
static constexpr char batch_message[] =
    "Batch size value. If not specified, the batch size value is taken from IR";
/// @brief Help text for --dump (CSV dump of results)
static constexpr char dump_message[] =
    "Dump file names and inference results to a .csv file";
/// @brief Help text for -t (network type)
static constexpr char type_message[] =
    "Type of an inferred network (\"C\" by default)";
/// @brief Help text for -ppType (preprocessing type)
static constexpr char preprocessing_type[] =
    "Preprocessing type. Options: \"None\", \"Resize\", \"ResizeCrop\"";
/// @brief Help text for -ppSize (square preprocessing size)
static constexpr char preprocessing_size[] =
    "Preprocessing size (used with ppType=\"ResizeCrop\")";
/// @brief Help text for -ppWidth
static constexpr char preprocessing_width[] =
    "Preprocessing width (overrides -ppSize, used with ppType=\"ResizeCrop\")";
/// @brief Help text for -ppHeight
static constexpr char preprocessing_height[] =
    "Preprocessing height (overrides -ppSize, used with ppType=\"ResizeCrop\")";

/// @brief Help text for -ODa (Object Detection annotations directory)
static constexpr char obj_detection_annotations_message[] =
    "Required for Object Detection models. Path to a directory containing an .xml file with annotations for images.";
74 static const char obj_detection_classes_message[] = "Required for Object Detection models. Path to a file containing"
/// @brief Message for -ODsubdir (subdirectory between the -i path and the image names from the annotations)
static const char obj_detection_subdir_message[] = "Directory between the path to images (specified with -i) and image name (specified in the"
                                                   " .xml file). For VOC2007 dataset, use JPEGImages.";
/// @brief Message for -ODkind. Keep the listed options in sync with the kinds main() dispatches on
/// ("SSD" and "YOLO").
static const char obj_detection_kind_message[] = "Type of an Object Detection model. Options: SSD, YOLO";

/// @brief Message for GPU custom kernels desc
static const char custom_cldnn_message[] = "Required for GPU custom kernels. "
                                           "Absolute path to an .xml file with the kernel descriptions.";

/// @brief Message for user library argument
static const char custom_cpu_library_message[] = "Required for CPU custom layers. "
                                                 "Absolute path to a shared library with the kernel implementations";

/// @brief Message for labels file
static const char labels_file_message[] = "Labels file path. The labels file contains names of the dataset classes";

/// @brief Message for the -Czb ("zero is a background") flag
static const char zero_background_message[] = "\"Zero is a background\" flag. Some networks are trained with a modified"
                                              " dataset where the class IDs"
                                              " are enumerated from 1, but 0 is an undefined \"background\" class"
                                              " (which is never detected)";

/// @brief Message for the plain output flag
static const char plain_output_message[] = "Flag for plain output";
100 /// @brief Network type options and their descriptions
101 static const char* types_descriptions[][2] = {
102 { "C", "classification" },
103 // { "SS", "semantic segmentation" }, // Not supported yet
104 { "OD", "object detection" },
/// @brief Define flag for showing help message <br>
/// main() also enters the help branch when the program is started without any arguments.
DEFINE_bool(h, false, help_message);
/// @brief Define parameter for a path to images <br>
/// It is a required parameter (main() reports an input error when it is empty)
DEFINE_string(i, "", image_message);
/// @brief Define parameter for a path to model file <br>
/// It is a required parameter (main() reports an input error when it is empty)
DEFINE_string(m, "", model_message);
/// @brief Define parameter for a plugin name <br>
/// It is an optional parameter: the plugin is selected by -d, and in the visible code
/// this value is only printed in a log message.
DEFINE_string(p, "", plugin_message);
/// @brief Define parameter for a path to a file with labels <br>
/// NOTE(review): FLAGS_OCl is not read anywhere in the visible code (classification labels
/// are passed via -lbl); confirm whether this flag is still needed.
DEFINE_string(OCl, "", label_message);
/// @brief Define parameter for a path to plugins <br>
/// Passed to PluginDispatcher when the plugin is loaded. NOTE(review): its help text
/// (plugin_path_message) describes a model .xml file; confirm the intended wording.
DEFINE_string(pp, "", plugin_path_message);
/// @brief Define parameter for a target device to infer on <br>
/// Defaults to "CPU"; the built-in CPU extensions are added whenever the value contains "CPU".
DEFINE_string(d, "CPU", target_device_message);
/// @brief Define parameter for batch size <br>
/// Default is 0 (which means that batch size is not specified); negative values are rejected
DEFINE_int32(b, 0, batch_message);
/// @brief Define flag to dump results to a file <br>
/// The value is forwarded to CsvDumper.
DEFINE_bool(dump, false, dump_message);
/// @brief Define parameter for a network type parameter
/// Accepted values: "C" (classification) and "OD" (object detection)
DEFINE_string(t, "C", type_message);

/// @brief Define parameter for preprocessing type
/// Compared case-insensitively with "None", "Resize" and "ResizeCrop"; an empty value means "Resize".
DEFINE_string(ppType, "", preprocessing_type);

/// @brief Define parameters for preprocessing size, used with the "ResizeCrop" type
/// -ppSize sets both dimensions; 0 means "not specified".
/// NOTE(review): main() assigns FLAGS_ppSize (not FLAGS_ppWidth / FLAGS_ppHeight) when
/// -ppWidth / -ppHeight are given, so the override promised by the help text does not take effect.
DEFINE_int32(ppSize, 0, preprocessing_size);
DEFINE_int32(ppWidth, 0, preprocessing_width);
DEFINE_int32(ppHeight, 0, preprocessing_height);

/// @brief Define the "zero is a background" flag (forwarded to ClassificationProcessor)
DEFINE_bool(Czb, false, zero_background_message);
/// @brief Define parameter for the annotations directory (required when -t OD)
DEFINE_string(ODa, "", obj_detection_annotations_message);
/// @brief Define parameter for the classes file (required when -t OD)
DEFINE_string(ODc, "", obj_detection_classes_message);
/// @brief Define parameter for the images subdirectory of an Object Detection dataset
DEFINE_string(ODsubdir, "", obj_detection_subdir_message);

/// @brief Define parameter for a type of Object Detection network
/// main() dispatches on "SSD" and "YOLO"
DEFINE_string(ODkind, "SSD", obj_detection_kind_message);

/// @brief Define parameter for GPU kernels path <br>
/// It is an optional parameter; when set, it is passed to the plugin as KEY_CONFIG_FILE
DEFINE_string(c, "", custom_cldnn_message);

/// @brief Define parameter for a path to CPU library with user layers <br>
/// It is an optional parameter
DEFINE_string(l, "", custom_cpu_library_message);

/// @brief Flag for printing plain text (forwarded to Processor::Process)
DEFINE_bool(plain, false, plain_output_message);
/// @brief Define parameter for a labels file path (forwarded to ClassificationProcessor)
DEFINE_string(lbl, "", labels_file_message);
168 * @brief This function shows a help message
170 static void showUsage() {
171 std::cout << std::endl;
172 std::cout << "Usage: validation_app [OPTION]" << std::endl << std::endl;
173 std::cout << "Available options:" << std::endl;
174 std::cout << std::endl;
175 std::cout << " -h " << help_message << std::endl;
176 std::cout << " -t <type> " << type_message << std::endl;
177 for (int i = 0; types_descriptions[i][0] != nullptr; i++) {
178 std::cout << " -t \"" << types_descriptions[i][0] << "\" for " << types_descriptions[i][1] << std::endl;
180 std::cout << " -i <path> " << image_message << std::endl;
181 std::cout << " -m <path> " << model_message << std::endl;
182 std::cout << " -lbl <path> " << labels_file_message << std::endl;
183 std::cout << " -l <absolute_path> " << custom_cpu_library_message << std::endl;
184 std::cout << " -c <absolute_path> " << custom_cldnn_message << std::endl;
185 std::cout << " -d <device> " << target_device_message << std::endl;
186 std::cout << " -b N " << batch_message << std::endl;
187 std::cout << " -ppType <type> " << preprocessing_type << std::endl;
188 std::cout << " -ppSize N " << preprocessing_size << std::endl;
189 std::cout << " -ppWidth W " << preprocessing_width << std::endl;
190 std::cout << " -ppHeight H " << preprocessing_height << std::endl;
191 std::cout << " --dump " << dump_message << std::endl;
193 std::cout << std::endl;
194 std::cout << " Classification-specific options:" << std::endl;
195 std::cout << " -Czb true " << zero_background_message << std::endl;
197 std::cout << std::endl;
198 std::cout << " Object detection-specific options:" << std::endl;
199 std::cout << " -ODkind <kind> " << obj_detection_kind_message << std::endl;
200 std::cout << " -ODa <path> " << obj_detection_annotations_message << std::endl;
201 std::cout << " -ODc <file> " << obj_detection_classes_message << std::endl;
202 std::cout << " -ODsubdir <name> " << obj_detection_subdir_message << std::endl << std::endl;
/**
 * @brief Returns a lowercase copy of a string
 * @param s - The string to convert
 * @return Copy of s in which every character has been passed through tolower()
 *
 * Each character is converted to unsigned char before calling tolower(): passing a
 * negative char value (e.g. a non-ASCII byte on platforms where char is signed)
 * is undefined behavior.
 */
std::string strtolower(const std::string& s) {
    std::string res = s;
    std::transform(res.begin(), res.end(), res.begin(),
                   [](unsigned char c) { return static_cast<char>(::tolower(c)); });
    return res;
}
218 * @brief The main function of Inference Engine sample application
219 * @param argc - The number of arguments
220 * @param argv - Arguments
221 * @return 0 if all good
223 int main(int argc, char *argv[]) {
225 slog::info << "InferenceEngine: " << GetInferenceEngineVersion() << slog::endl;
227 // ---------------------------Parsing and validating input arguments--------------------------------------
228 slog::info << "Parsing input parameters" << slog::endl;
230 bool noOptions = argc == 1;
232 gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
233 if (FLAGS_h || noOptions) {
240 NetworkType netType = Undefined;
241 // Checking the network type
242 if (std::string(FLAGS_t) == "C") {
243 netType = Classification;
244 } else if (std::string(FLAGS_t) == "OD") {
245 netType = ObjDetection;
247 ee << UserException(5, "Unknown network type specified (invalid -t option)");
250 // Checking required options
251 if (FLAGS_m.empty()) ee << UserException(3, "Model file is not specified (missing -m option)");
252 if (FLAGS_i.empty()) ee << UserException(4, "Images list is not specified (missing -i option)");
253 if (FLAGS_d.empty()) ee << UserException(5, "Target device is not specified (missing -d option)");
254 if (FLAGS_b < 0) ee << UserException(6, "Batch must be positive (invalid -b option value)");
256 if (netType == ObjDetection) {
257 // Checking required OD-specific options
258 if (FLAGS_ODa.empty()) ee << UserException(11, "Annotations folder is not specified for object detection (missing -a option)");
259 if (FLAGS_ODc.empty()) ee << UserException(12, "Classes file is not specified (missing -c option)");
262 if (!ee.empty()) throw ee;
263 // -----------------------------------------------------------------------------------------------------
265 // ---------------------Loading plugin for Inference Engine------------------------------------------------
266 slog::info << "Loading plugin" << slog::endl;
267 /** Loading the library with extensions if provided**/
268 InferencePlugin plugin = PluginDispatcher({ FLAGS_pp }).getPluginByDevice(FLAGS_d);
270 /** Loading default extensions **/
271 if (FLAGS_d.find("CPU") != std::string::npos) {
273 * cpu_extensions library is compiled from "extension" folder containing
274 * custom CPU plugin layer implementations. These layers are not supported
275 * by CPU, but they can be useful for inferring custom topologies.
277 plugin.AddExtension(std::make_shared<Extensions::Cpu::CpuExtensions>());
280 if (!FLAGS_l.empty()) {
281 // CPU extensions are loaded as a shared library and passed as a pointer to base extension
282 IExtensionPtr extension_ptr = make_so_pointer<IExtension>(FLAGS_l);
283 plugin.AddExtension(extension_ptr);
284 slog::info << "CPU Extension loaded: " << FLAGS_l << slog::endl;
286 if (!FLAGS_c.empty()) {
287 // CPU extensions are loaded from an .xml description and OpenCL kernel files
288 plugin.SetConfig({{PluginConfigParams::KEY_CONFIG_FILE, FLAGS_c}});
289 slog::info << "GPU Extension loaded: " << FLAGS_c << slog::endl;
292 printPluginVersion(plugin, std::cout);
294 CsvDumper dumper(FLAGS_dump);
296 std::shared_ptr<Processor> processor;
298 PreprocessingOptions preprocessingOptions;
299 if (strtolower(FLAGS_ppType.c_str()) == "none") {
300 preprocessingOptions = PreprocessingOptions(false, ResizeCropPolicy::DoNothing);
301 } else if (strtolower(FLAGS_ppType) == "resizecrop") {
302 size_t ppWidth = FLAGS_ppSize;
303 size_t ppHeight = FLAGS_ppSize;
305 if (FLAGS_ppWidth > 0) ppWidth = FLAGS_ppSize;
306 if (FLAGS_ppHeight > 0) ppHeight = FLAGS_ppSize;
308 if (FLAGS_ppSize > 0 || (FLAGS_ppWidth > 0 && FLAGS_ppHeight > 0)) {
309 preprocessingOptions = PreprocessingOptions(false, ResizeCropPolicy::ResizeThenCrop, ppWidth, ppHeight);
311 THROW_USER_EXCEPTION(2) << "Size must be specified for preprocessing type " << FLAGS_ppType;
313 } else if (strtolower(FLAGS_ppType) == "resize" || FLAGS_ppType.empty()) {
314 preprocessingOptions = PreprocessingOptions(false, ResizeCropPolicy::Resize);
316 THROW_USER_EXCEPTION(2) << "Unknown preprocessing type: " << FLAGS_ppType;
319 if (netType == Classification) {
320 processor = std::shared_ptr<Processor>(
321 new ClassificationProcessor(FLAGS_m, FLAGS_d, FLAGS_i, FLAGS_b,
322 plugin, dumper, FLAGS_lbl, preprocessingOptions, FLAGS_Czb));
323 } else if (netType == ObjDetection) {
324 if (FLAGS_ODkind == "SSD") {
325 processor = std::shared_ptr<Processor>(
326 new SSDObjectDetectionProcessor(FLAGS_m, FLAGS_d, FLAGS_i, FLAGS_ODsubdir, FLAGS_b,
327 0.5, plugin, dumper, FLAGS_ODa, FLAGS_ODc));
328 } else if (FLAGS_ODkind == "YOLO") {
329 processor = std::shared_ptr<Processor>(
330 new YOLOObjectDetectionProcessor(FLAGS_m, FLAGS_d, FLAGS_i, FLAGS_ODsubdir, FLAGS_b,
331 0.5, plugin, dumper, FLAGS_ODa, FLAGS_ODc));
334 THROW_USER_EXCEPTION(2) << "Unknown network type specified" << FLAGS_ppType;
336 if (!processor.get()) {
337 THROW_USER_EXCEPTION(2) << "Processor pointer is invalid" << FLAGS_ppType;
339 slog::info << (FLAGS_d.empty() ? "Plugin: " + FLAGS_p : "Device: " + FLAGS_d) << slog::endl;
340 shared_ptr<Processor::InferenceMetrics> pIM = processor->Process(FLAGS_plain);
341 processor->Report(*pIM.get());
343 if (dumper.dumpEnabled()) {
344 slog::info << "Dump file generated: " << dumper.getFilename() << slog::endl;
346 } catch (const InferenceEngineException& ex) {
347 slog::err << "Inference problem: \n" << ex.what() << slog::endl;
349 } catch (const UserException& ex) {
350 slog::err << "Input problem: \n" << ex.what() << slog::endl;
352 return ex.exitCode();
353 } catch (const UserExceptions& ex) {
354 if (ex.list().size() == 1) {
355 slog::err << "Input problem: " << ex.what() << slog::endl;
357 return ex.list().begin()->exitCode();
359 slog::err << "Input problems: \n" << ex.what() << slog::endl;
361 return ex.list().begin()->exitCode();