Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / samples / validation_app / main.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 /**
6  * @brief The entry point for Inference Engine validation application
7  * @file validation_app/main.cpp
8  */
9 #include <gflags/gflags.h>
10 #include <algorithm>
11 #include <functional>
12 #include <iostream>
13 #include <map>
14 #include <fstream>
15 #include <random>
16 #include <string>
17 #include <tuple>
18 #include <vector>
19 #include <limits>
20 #include <iomanip>
21 #include <memory>
22
23 #include <ext_list.hpp>
24
25 #include <samples/common.hpp>
26 #include <samples/slog.hpp>
27
28 #include "user_exception.hpp"
29 #include "ClassificationProcessor.hpp"
30 #include "SSDObjectDetectionProcessor.hpp"
31 #include "YOLOObjectDetectionProcessor.hpp"
32
33 using namespace std;
34 using namespace InferenceEngine;
35
36 using InferenceEngine::details::InferenceEngineException;
37
38 /// @brief Message for help argument
39 static const char help_message[] = "Print a help message";
40 /// @brief Message for images argument
41 static const char image_message[] = "Required. Folder with validation images. Path to a directory with validation images. For Classification models,"
42                                     " the directory must contain folders named as labels with images inside or a .txt file with"
43                                     " a list of images. For Object Detection models, the dataset must be in"
44                                     " VOC format.";
45 /// @brief Message for plugin_path argument
46 static const char plugin_path_message[] = "Required. Path to an .xml file with a trained model, including model name and "
47                                           "extension.";
48 /// @brief Message for model argument
49 static const char model_message[] = "Required. Path to an .xml file with a trained model";
50 /// @brief Message for plugin argument
51 static const char plugin_message[] = "Plugin name. For example, CPU. If this parameter is passed, "
52                                      "the sample looks for a specified plugin only.";
53 /// @brief Message for assigning cnn calculation to device
54 static const char target_device_message[] = "Target device to infer on: CPU (default), GPU, FPGA, HDDL or MYRIAD."
55                                             " The application looks for a suitable plugin for the specified device.";
56 /// @brief Message for label argument
57 static const char label_message[] = "Path to a file with labels for a model";
58 /// @brief Message for batch argumenttype
59 static const char batch_message[] = "Batch size value. If not specified, the batch size value is taken from IR";
60 /// @brief Message for dump argument
61 static const char dump_message[] = "Dump file names and inference results to a .csv file";
62 /// @brief Message for network type
63 static const char type_message[] = "Type of an inferred network (\"C\" by default)";
64 /// @brief Message for pp-type
65 static const char preprocessing_type[] = "Preprocessing type. Options: \"None\", \"Resize\", \"ResizeCrop\"";
66 /// @brief Message for pp-crop-size
67 static const char preprocessing_size[] = "Preprocessing size (used with ppType=\"ResizeCrop\")";
68 static const char preprocessing_width[] = "Preprocessing width (overrides -ppSize, used with ppType=\"ResizeCrop\")";
69 static const char preprocessing_height[] = "Preprocessing height (overrides -ppSize, used with ppType=\"ResizeCrop\")";
70
71 static const char obj_detection_annotations_message[] = "Required for Object Detection models. Path to a directory"
72                                                         " containing an .xml file with annotations for images.";
73
74 static const char obj_detection_classes_message[] = "Required for Object Detection models. Path to a file containing"
75                                                     " a list of classes";
76
77 static const char obj_detection_subdir_message[] = "Directory between the path to images (specified with -i) and image name (specified in the"
78                                                    " .xml file). For VOC2007 dataset, use JPEGImages.";
79 static const char obj_detection_kind_message[] = "Type of an Object Detection model. Options: SSD";
80
81 /// @brief Message for GPU custom kernels desc
82 static const char custom_cldnn_message[] = "Required for GPU custom kernels."
83                                            "Absolute path to an .xml file with the kernel descriptions.";
84
85 /// @brief Message for user library argument
86 static const char custom_cpu_library_message[] = "Required for CPU custom layers. "
87                                                  "Absolute path to a shared library with the kernel implementations";
88
89 /// @brief Message for labels file
90 static const char labels_file_message[] = "Labels file path. The labels file contains names of the dataset classes";
91
92 static const char zero_background_message[] = "\"Zero is a background\" flag. Some networks are trained with a modified"
93                                               " dataset where the class IDs "
94                                               " are enumerated from 1, but 0 is an undefined \"background\" class"
95                                               " (which is never detected)";
96
97 static const char plain_output_message[] = "Flag for plain output";
98
99
100 /// @brief Network type options and their descriptions
101 static const char* types_descriptions[][2] = {
102     { "C", "classification" },
103 //    { "SS", "semantic segmentation" },    // Not supported yet
104     { "OD", "object detection" },
105     { nullptr, nullptr }
106 };
107
108 /// @brief Define flag for showing help message <br>
109 DEFINE_bool(h, false, help_message);
110 /// @brief Define parameter for a path to images <br>
111 /// It is a required parameter
112 DEFINE_string(i, "", image_message);
113 /// @brief Define parameter for a path to model file <br>
114 /// It is a required parameter
115 DEFINE_string(m, "", model_message);
116 /// @brief Define parameter for a plugin name <br>
117 /// It is a required parameter
118 DEFINE_string(p, "", plugin_message);
119 /// @brief Define parameter for a path to a file with labels <br>
120 /// Default is empty
121 DEFINE_string(OCl, "", label_message);
122 /// @brief Define parameter for a path to plugins <br>
123 /// Default is ./lib
124 DEFINE_string(pp, "", plugin_path_message);
125 /// @brief Define parameter for a target device to infer on <br>
126 DEFINE_string(d, "CPU", target_device_message);
127 /// @brief Define parameter for batch size <br>
128 /// Default is 0 (which means that batch size is not specified)
129 DEFINE_int32(b, 0, batch_message);
130 /// @brief Define flag to dump results to a file <br>
131 DEFINE_bool(dump, false, dump_message);
132 /// @brief Define parameter for a network type parameter
133 DEFINE_string(t, "C", type_message);
134
135 /// @brief Define parameter for preprocessing type
136 DEFINE_string(ppType, "", preprocessing_type);
137
138 /// @brief Define parameter for preprocessing size
139 DEFINE_int32(ppSize, 0, preprocessing_size);
140 DEFINE_int32(ppWidth, 0, preprocessing_width);
141 DEFINE_int32(ppHeight, 0, preprocessing_height);
142
143 DEFINE_bool(Czb, false, zero_background_message);
144
145 DEFINE_string(ODa, "", obj_detection_annotations_message);
146
147 DEFINE_string(ODc, "", obj_detection_classes_message);
148
149 DEFINE_string(ODsubdir, "", obj_detection_subdir_message);
150
151 /// @brief Define parameter for a type of Object Detection network
152 DEFINE_string(ODkind, "SSD", obj_detection_kind_message);
153
154 /// @brief Define parameter for GPU kernels path <br>
155 /// Default is ./lib
156 DEFINE_string(c, "", custom_cldnn_message);
157
158 /// @brief Define parameter for a path to CPU library with user layers <br>
159 /// It is an optional parameter
160 DEFINE_string(l, "", custom_cpu_library_message);
161
162 /// @brief Flag for printing plain text
163 DEFINE_bool(plain, false, plain_output_message);
164
165 DEFINE_string(lbl, "", labels_file_message);
166
167 /**
168  * @brief This function shows a help message
169  */
170 static void showUsage() {
171     std::cout << std::endl;
172     std::cout << "Usage: validation_app [OPTION]" << std::endl << std::endl;
173     std::cout << "Available options:" << std::endl;
174     std::cout << std::endl;
175     std::cout << "    -h                        " << help_message << std::endl;
176     std::cout << "    -t <type>                 " << type_message << std::endl;
177     for (int i = 0; types_descriptions[i][0] != nullptr; i++) {
178         std::cout << "      -t \"" << types_descriptions[i][0] << "\" for " << types_descriptions[i][1] << std::endl;
179     }
180     std::cout << "    -i <path>                 " << image_message << std::endl;
181     std::cout << "    -m <path>                 " << model_message << std::endl;
182     std::cout << "    -lbl <path>               " << labels_file_message << std::endl;
183     std::cout << "    -l <absolute_path>        " << custom_cpu_library_message << std::endl;
184     std::cout << "    -c <absolute_path>        " << custom_cldnn_message << std::endl;
185     std::cout << "    -d <device>               " << target_device_message << std::endl;
186     std::cout << "    -b N                      " << batch_message << std::endl;
187     std::cout << "    -ppType <type>            " << preprocessing_type << std::endl;
188     std::cout << "    -ppSize N                 " << preprocessing_size << std::endl;
189     std::cout << "    -ppWidth W                " << preprocessing_width << std::endl;
190     std::cout << "    -ppHeight H               " << preprocessing_height << std::endl;
191     std::cout << "    --dump                    " << dump_message << std::endl;
192
193     std::cout << std::endl;
194     std::cout << "    Classification-specific options:" << std::endl;
195     std::cout << "      -Czb true               " << zero_background_message << std::endl;
196
197     std::cout << std::endl;
198     std::cout << "    Object detection-specific options:" << std::endl;
199     std::cout << "      -ODkind <kind>          " << obj_detection_kind_message << std::endl;
200     std::cout << "      -ODa <path>             " << obj_detection_annotations_message << std::endl;
201     std::cout << "      -ODc <file>             " << obj_detection_classes_message << std::endl;
202     std::cout << "      -ODsubdir <name>        " << obj_detection_subdir_message << std::endl << std::endl;
203 }
204
205 enum NetworkType {
206     Undefined = -1,
207     Classification,
208     ObjDetection
209 };
210
211 std::string strtolower(const std::string& s) {
212     std::string res = s;
213     std::transform(res.begin(), res.end(), res.begin(), ::tolower);
214     return res;
215 }
216
217 /**
218  * @brief The main function of Inference Engine sample application
219  * @param argc - The number of arguments
220  * @param argv - Arguments
221  * @return 0 if all good
222  */
223 int main(int argc, char *argv[]) {
224     try {
225         slog::info << "InferenceEngine: " << GetInferenceEngineVersion() << slog::endl;
226
227         // ---------------------------Parsing and validating input arguments--------------------------------------
228         slog::info << "Parsing input parameters" << slog::endl;
229
230         bool noOptions = argc == 1;
231
232         gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
233         if (FLAGS_h || noOptions) {
234             showUsage();
235             return 1;
236         }
237
238         UserExceptions ee;
239
240         NetworkType netType = Undefined;
241         // Checking the network type
242         if (std::string(FLAGS_t) == "C") {
243             netType = Classification;
244         } else if (std::string(FLAGS_t) == "OD") {
245             netType = ObjDetection;
246         } else {
247             ee << UserException(5, "Unknown network type specified (invalid -t option)");
248         }
249
250         // Checking required options
251         if (FLAGS_m.empty()) ee << UserException(3, "Model file is not specified (missing -m option)");
252         if (FLAGS_i.empty()) ee << UserException(4, "Images list is not specified (missing -i option)");
253         if (FLAGS_d.empty()) ee << UserException(5, "Target device is not specified (missing -d option)");
254         if (FLAGS_b < 0) ee << UserException(6, "Batch must be positive (invalid -b option value)");
255
256         if (netType == ObjDetection) {
257             // Checking required OD-specific options
258             if (FLAGS_ODa.empty()) ee << UserException(11, "Annotations folder is not specified for object detection (missing -a option)");
259             if (FLAGS_ODc.empty()) ee << UserException(12, "Classes file is not specified (missing -c option)");
260         }
261
262         if (!ee.empty()) throw ee;
263         // -----------------------------------------------------------------------------------------------------
264
265         // ---------------------Loading plugin for Inference Engine------------------------------------------------
266         slog::info << "Loading plugin" << slog::endl;
267         /** Loading the library with extensions if provided**/
268         InferencePlugin plugin = PluginDispatcher({ FLAGS_pp }).getPluginByDevice(FLAGS_d);
269
270         /** Loading default extensions **/
271         if (FLAGS_d.find("CPU") != std::string::npos) {
272             /**
273              * cpu_extensions library is compiled from "extension" folder containing
274              * custom CPU plugin layer implementations. These layers are not supported
275              * by CPU, but they can be useful for inferring custom topologies.
276             **/
277             plugin.AddExtension(std::make_shared<Extensions::Cpu::CpuExtensions>());
278         }
279
280         if (!FLAGS_l.empty()) {
281             // CPU extensions are loaded as a shared library and passed as a pointer to base extension
282             IExtensionPtr extension_ptr = make_so_pointer<IExtension>(FLAGS_l);
283             plugin.AddExtension(extension_ptr);
284             slog::info << "CPU Extension loaded: " << FLAGS_l << slog::endl;
285         }
286         if (!FLAGS_c.empty()) {
287             // CPU extensions are loaded from an .xml description and OpenCL kernel files
288             plugin.SetConfig({{PluginConfigParams::KEY_CONFIG_FILE, FLAGS_c}});
289             slog::info << "GPU Extension loaded: " << FLAGS_c << slog::endl;
290         }
291
292         printPluginVersion(plugin, std::cout);
293
294         CsvDumper dumper(FLAGS_dump);
295
296         std::shared_ptr<Processor> processor;
297
298         PreprocessingOptions preprocessingOptions;
299         if (strtolower(FLAGS_ppType.c_str()) == "none") {
300             preprocessingOptions = PreprocessingOptions(false, ResizeCropPolicy::DoNothing);
301         } else if (strtolower(FLAGS_ppType) == "resizecrop") {
302             size_t ppWidth = FLAGS_ppSize;
303             size_t ppHeight = FLAGS_ppSize;
304
305             if (FLAGS_ppWidth > 0) ppWidth = FLAGS_ppSize;
306             if (FLAGS_ppHeight > 0) ppHeight = FLAGS_ppSize;
307
308             if (FLAGS_ppSize > 0 || (FLAGS_ppWidth > 0 && FLAGS_ppHeight > 0)) {
309                 preprocessingOptions = PreprocessingOptions(false, ResizeCropPolicy::ResizeThenCrop, ppWidth, ppHeight);
310             } else {
311                 THROW_USER_EXCEPTION(2) << "Size must be specified for preprocessing type " << FLAGS_ppType;
312             }
313         } else if (strtolower(FLAGS_ppType) == "resize" || FLAGS_ppType.empty()) {
314             preprocessingOptions = PreprocessingOptions(false, ResizeCropPolicy::Resize);
315         } else {
316             THROW_USER_EXCEPTION(2) << "Unknown preprocessing type: " << FLAGS_ppType;
317         }
318
319         if (netType == Classification) {
320             processor = std::shared_ptr<Processor>(
321                     new ClassificationProcessor(FLAGS_m, FLAGS_d, FLAGS_i, FLAGS_b,
322                                                 plugin, dumper, FLAGS_lbl, preprocessingOptions, FLAGS_Czb));
323         } else if (netType == ObjDetection) {
324             if (FLAGS_ODkind == "SSD") {
325                 processor = std::shared_ptr<Processor>(
326                         new SSDObjectDetectionProcessor(FLAGS_m, FLAGS_d, FLAGS_i, FLAGS_ODsubdir, FLAGS_b,
327                                                         0.5, plugin, dumper, FLAGS_ODa, FLAGS_ODc));
328             } else if (FLAGS_ODkind == "YOLO") {
329                 processor = std::shared_ptr<Processor>(
330                         new YOLOObjectDetectionProcessor(FLAGS_m, FLAGS_d, FLAGS_i, FLAGS_ODsubdir, FLAGS_b,
331                                                          0.5, plugin, dumper, FLAGS_ODa, FLAGS_ODc));
332             }
333         } else {
334             THROW_USER_EXCEPTION(2) <<  "Unknown network type specified" << FLAGS_ppType;
335         }
336         if (!processor.get()) {
337             THROW_USER_EXCEPTION(2) <<  "Processor pointer is invalid" << FLAGS_ppType;
338         }
339         slog::info << (FLAGS_d.empty() ? "Plugin: " + FLAGS_p : "Device: " + FLAGS_d) << slog::endl;
340         shared_ptr<Processor::InferenceMetrics> pIM = processor->Process(FLAGS_plain);
341         processor->Report(*pIM.get());
342
343         if (dumper.dumpEnabled()) {
344             slog::info << "Dump file generated: " << dumper.getFilename() << slog::endl;
345         }
346     } catch (const InferenceEngineException& ex) {
347         slog::err << "Inference problem: \n" << ex.what() << slog::endl;
348         return 1;
349     } catch (const UserException& ex) {
350         slog::err << "Input problem: \n" << ex.what() << slog::endl;
351         showUsage();
352         return ex.exitCode();
353     } catch (const UserExceptions& ex) {
354         if (ex.list().size() == 1) {
355             slog::err << "Input problem: " << ex.what() << slog::endl;
356             showUsage();
357             return ex.list().begin()->exitCode();
358         } else {
359             slog::err << "Input problems: \n" << ex.what() << slog::endl;
360             showUsage();
361             return ex.list().begin()->exitCode();
362         }
363     }
364     return 0;
365 }