15500bc645bd425d9e64ee107650937e4111db48
[platform/upstream/dldt.git] / inference-engine / samples / classification_sample / main.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <fstream>
7 #include <vector>
8 #include <chrono>
9 #include <memory>
10 #include <string>
11
12 #include <inference_engine.hpp>
13 #include <ext_list.hpp>
14 #include <format_reader_ptr.h>
15
16 #include <samples/common.hpp>
17 #include <samples/slog.hpp>
18 #include <samples/args_helper.hpp>
19
20 #include "classification_sample.h"
21
22 using namespace InferenceEngine;
23
24 ConsoleErrorListener error_listener;
25
26 bool ParseAndCheckCommandLine(int argc, char *argv[]) {
27     // ---------------------------Parsing and validation of input args--------------------------------------
28     gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
29     if (FLAGS_h) {
30         showUsage();
31         return false;
32     }
33     slog::info << "Parsing input parameters" << slog::endl;
34
35     if (FLAGS_ni < 1) {
36         throw std::logic_error("Parameter -ni should be greater than zero (default 1)");
37     }
38
39     if (FLAGS_i.empty()) {
40         throw std::logic_error("Parameter -i is not set");
41     }
42
43     if (FLAGS_m.empty()) {
44         throw std::logic_error("Parameter -m is not set");
45     }
46
47     return true;
48 }
49
50 /**
51 * @brief The entry point the Inference Engine sample application
52 * @file classification_sample/main.cpp
53 * @example classification_sample/main.cpp
54 */
55 int main(int argc, char *argv[]) {
56     try {
57         slog::info << "InferenceEngine: " << GetInferenceEngineVersion() << slog::endl;
58
59         // ------------------------------ Parsing and validation of input args ---------------------------------
60         if (!ParseAndCheckCommandLine(argc, argv)) {
61             return 0;
62         }
63
64         /** This vector stores paths to the processed images **/
65         std::vector<std::string> imageNames;
66         parseInputFilesArguments(imageNames);
67         if (imageNames.empty()) throw std::logic_error("No suitable images were found");
68         // -----------------------------------------------------------------------------------------------------
69
70         // --------------------------- 1. Load Plugin for inference engine -------------------------------------
71         slog::info << "Loading plugin" << slog::endl;
72         InferencePlugin plugin = PluginDispatcher({ FLAGS_pp, "../../../lib/intel64" , "" }).getPluginByDevice(FLAGS_d);
73         if (FLAGS_p_msg) {
74             static_cast<InferenceEngine::InferenceEnginePluginPtr>(plugin)->SetLogCallback(error_listener);
75         }
76
77         /** Loading default extensions **/
78         if (FLAGS_d.find("CPU") != std::string::npos) {
79             /**
80              * cpu_extensions library is compiled from "extension" folder containing
81              * custom MKLDNNPlugin layer implementations. These layers are not supported
82              * by mkldnn, but they can be useful for inferring custom topologies.
83             **/
84             plugin.AddExtension(std::make_shared<Extensions::Cpu::CpuExtensions>());
85         }
86
87         if (!FLAGS_l.empty()) {
88             // CPU(MKLDNN) extensions are loaded as a shared library and passed as a pointer to base extension
89             auto extension_ptr = make_so_pointer<IExtension>(FLAGS_l);
90             plugin.AddExtension(extension_ptr);
91             slog::info << "CPU Extension loaded: " << FLAGS_l << slog::endl;
92         }
93         if (!FLAGS_c.empty()) {
94             // clDNN Extensions are loaded from an .xml description and OpenCL kernel files
95             plugin.SetConfig({{PluginConfigParams::KEY_CONFIG_FILE, FLAGS_c}});
96             slog::info << "GPU Extension loaded: " << FLAGS_c << slog::endl;
97         }
98
99         /** Setting plugin parameter for collecting per layer metrics **/
100         if (FLAGS_pc) {
101             plugin.SetConfig({ { PluginConfigParams::KEY_PERF_COUNT, PluginConfigParams::YES } });
102         }
103
104         /** Printing plugin version **/
105         printPluginVersion(plugin, std::cout);
106         // -----------------------------------------------------------------------------------------------------
107
108         // --------------------------- 2. Read IR Generated by ModelOptimizer (.xml and .bin files) ------------
109         std::string binFileName = fileNameNoExt(FLAGS_m) + ".bin";
110         slog::info << "Loading network files:"
111                 "\n\t" << FLAGS_m <<
112                 "\n\t" << binFileName <<
113         slog::endl;
114
115         CNNNetReader networkReader;
116         /** Reading network model **/
117         networkReader.ReadNetwork(FLAGS_m);
118
119         /** Extracting model name and loading weights **/
120         networkReader.ReadWeights(binFileName);
121         CNNNetwork network = networkReader.getNetwork();
122         // -----------------------------------------------------------------------------------------------------
123
124         // --------------------------- 3. Configure input & output ---------------------------------------------
125
126         // --------------------------- Prepare input blobs -----------------------------------------------------
127         slog::info << "Preparing input blobs" << slog::endl;
128
129         /** Taking information about all topology inputs **/
130         InputsDataMap inputInfo = network.getInputsInfo();
131         if (inputInfo.size() != 1) throw std::logic_error("Sample supports topologies only with 1 input");
132
133         auto inputInfoItem = *inputInfo.begin();
134
135         /** Specifying the precision and layout of input data provided by the user.
136          * This should be called before load of the network to the plugin **/
137         inputInfoItem.second->setPrecision(Precision::U8);
138         inputInfoItem.second->setLayout(Layout::NCHW);
139
140         std::vector<std::shared_ptr<unsigned char>> imagesData;
141         for (auto & i : imageNames) {
142             FormatReader::ReaderPtr reader(i.c_str());
143             if (reader.get() == nullptr) {
144                 slog::warn << "Image " + i + " cannot be read!" << slog::endl;
145                 continue;
146             }
147             /** Store image data **/
148             std::shared_ptr<unsigned char> data(
149                     reader->getData(inputInfoItem.second->getTensorDesc().getDims()[3],
150                                     inputInfoItem.second->getTensorDesc().getDims()[2]));
151             if (data.get() != nullptr) {
152                 imagesData.push_back(data);
153             }
154         }
155         if (imagesData.empty()) throw std::logic_error("Valid input images were not found!");
156
157         /** Setting batch size using image count **/
158         network.setBatchSize(imagesData.size());
159         size_t batchSize = network.getBatchSize();
160         slog::info << "Batch size is " << std::to_string(batchSize) << slog::endl;
161
162         // ------------------------------ Prepare output blobs -------------------------------------------------
163         slog::info << "Preparing output blobs" << slog::endl;
164
165         OutputsDataMap outputInfo(network.getOutputsInfo());
166         // BlobMap outputBlobs;
167         std::string firstOutputName;
168
169         for (auto & item : outputInfo) {
170             if (firstOutputName.empty()) {
171                 firstOutputName = item.first;
172             }
173             DataPtr outputData = item.second;
174             if (!outputData) {
175                 throw std::logic_error("output data pointer is not valid");
176             }
177
178             item.second->setPrecision(Precision::FP32);
179         }
180
181         const SizeVector outputDims = outputInfo.begin()->second->getDims();
182
183         bool outputCorrect = false;
184         if (outputDims.size() == 2 /* NC */) {
185             outputCorrect = true;
186         } else if (outputDims.size() == 4 /* NCHW */) {
187             /* H = W = 1 */
188             if (outputDims[2] == 1 && outputDims[3] == 1) outputCorrect = true;
189         }
190
191         if (!outputCorrect) {
192             throw std::logic_error("Incorrect output dimensions for classification model");
193         }
194         // -----------------------------------------------------------------------------------------------------
195
196         // --------------------------- 4. Loading model to the plugin ------------------------------------------
197         slog::info << "Loading model to the plugin" << slog::endl;
198
199         ExecutableNetwork executable_network = plugin.LoadNetwork(network, {});
200         inputInfoItem.second = {};
201         outputInfo = {};
202         network = {};
203         networkReader = {};
204         // -----------------------------------------------------------------------------------------------------
205
206         // --------------------------- 5. Create infer request -------------------------------------------------
207         InferRequest infer_request = executable_network.CreateInferRequest();
208         // -----------------------------------------------------------------------------------------------------
209
210         // --------------------------- 6. Prepare input --------------------------------------------------------
211         /** Iterate over all the input blobs **/
212         for (const auto & item : inputInfo) {
213             /** Creating input blob **/
214             Blob::Ptr input = infer_request.GetBlob(item.first);
215
216             /** Filling input tensor with images. First b channel, then g and r channels **/
217             size_t num_channels = input->getTensorDesc().getDims()[1];
218             size_t image_size = input->getTensorDesc().getDims()[2] * input->getTensorDesc().getDims()[3];
219
220             auto data = input->buffer().as<PrecisionTrait<Precision::U8>::value_type*>();
221
222             /** Iterate over all input images **/
223             for (size_t image_id = 0; image_id < imagesData.size(); ++image_id) {
224                 /** Iterate over all pixel in image (b,g,r) **/
225                 for (size_t pid = 0; pid < image_size; pid++) {
226                     /** Iterate over all channels **/
227                     for (size_t ch = 0; ch < num_channels; ++ch) {
228                         /**          [images stride + channels stride + pixel id ] all in bytes            **/
229                         data[image_id * image_size * num_channels + ch * image_size + pid ] = imagesData.at(image_id).get()[pid*num_channels + ch];
230                     }
231                 }
232             }
233         }
234         inputInfo = {};
235         // -----------------------------------------------------------------------------------------------------
236
237         // --------------------------- 7. Do inference ---------------------------------------------------------
238         slog::info << "Starting inference (" << FLAGS_ni << " iterations)" << slog::endl;
239
240         typedef std::chrono::high_resolution_clock Time;
241         typedef std::chrono::duration<double, std::ratio<1, 1000>> ms;
242         typedef std::chrono::duration<float> fsec;
243
244         double total = 0.0;
245         /** Start inference & calc performance **/
246         for (int iter = 0; iter < FLAGS_ni; ++iter) {
247             auto t0 = Time::now();
248             infer_request.Infer();
249             auto t1 = Time::now();
250             fsec fs = t1 - t0;
251             ms d = std::chrono::duration_cast<ms>(fs);
252             total += d.count();
253         }
254         // -----------------------------------------------------------------------------------------------------
255
256         // --------------------------- 8. Process output -------------------------------------------------------
257         slog::info << "Processing output blobs" << slog::endl;
258
259         const Blob::Ptr output_blob = infer_request.GetBlob(firstOutputName);
260         auto output_data = output_blob->buffer().as<PrecisionTrait<Precision::FP32>::value_type*>();
261
262         /** Validating -nt value **/
263         const int resultsCnt = output_blob->size() / batchSize;
264         if (FLAGS_nt > resultsCnt || FLAGS_nt < 1) {
265             slog::warn << "-nt " << FLAGS_nt << " is not available for this network (-nt should be less than " \
266                       << resultsCnt+1 << " and more than 0)\n            will be used maximal value : " << resultsCnt;
267             FLAGS_nt = resultsCnt;
268         }
269
270         /** This vector stores id's of top N results **/
271         std::vector<unsigned> results;
272         TopResults(FLAGS_nt, *output_blob, results);
273
274         std::cout << std::endl << "Top " << FLAGS_nt << " results:" << std::endl << std::endl;
275
276         /** Read labels from file (e.x. AlexNet.labels) **/
277         bool labelsEnabled = false;
278         std::string labelFileName = fileNameNoExt(FLAGS_m) + ".labels";
279         std::vector<std::string> labels;
280
281         std::ifstream inputFile;
282         inputFile.open(labelFileName, std::ios::in);
283         if (inputFile.is_open()) {
284             std::string strLine;
285             while (std::getline(inputFile, strLine)) {
286                 trim(strLine);
287                 labels.push_back(strLine);
288             }
289             labelsEnabled = true;
290         }
291
292         /** Print the result iterating over each batch **/
293         for (int image_id = 0; image_id < batchSize; ++image_id) {
294             std::cout << "Image " << imageNames[image_id] << std::endl << std::endl;
295             for (size_t id = image_id * FLAGS_nt, cnt = 0; cnt < FLAGS_nt; ++cnt, ++id) {
296                 std::cout.precision(7);
297                 /** Getting probability for resulting class **/
298                 const auto result = output_data[results[id] + image_id*(output_blob->size() / batchSize)];
299                 std::cout << std::left << std::fixed << results[id] << " " << result;
300                 if (labelsEnabled) {
301                     std::cout << " label " << labels[results[id]] << std::endl;
302                 } else {
303                     std::cout << " label #" << results[id] << std::endl;
304                 }
305             }
306             std::cout << std::endl;
307         }
308         // -----------------------------------------------------------------------------------------------------
309         std::cout << std::endl << "total inference time: " << total << std::endl;
310         std::cout << "Average running time of one iteration: " << total / static_cast<double>(FLAGS_ni) << " ms" << std::endl;
311         std::cout << std::endl << "Throughput: " << 1000 * static_cast<double>(FLAGS_ni) * batchSize / total << " FPS" << std::endl;
312         std::cout << std::endl;
313
314         /** Show performance results **/
315         if (FLAGS_pc) {
316             printPerformanceCounts(infer_request, std::cout);
317         }
318     }
319     catch (const std::exception& error) {
320         slog::err << "" << error.what() << slog::endl;
321         return 1;
322     }
323     catch (...) {
324         slog::err << "Unknown/internal exception happened." << slog::endl;
325         return 1;
326     }
327
328     slog::info << "Execution successful" << slog::endl;
329     return 0;
330 }