1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
12 #include <inference_engine.hpp>
13 #include <ext_list.hpp>
14 #include <format_reader_ptr.h>
16 #include <samples/common.hpp>
17 #include <samples/slog.hpp>
18 #include <samples/args_helper.hpp>
20 #include "classification_sample.h"
22 using namespace InferenceEngine;
24 ConsoleErrorListener error_listener;
26 bool ParseAndCheckCommandLine(int argc, char *argv[]) {
27 // ---------------------------Parsing and validation of input args--------------------------------------
28 gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
33 slog::info << "Parsing input parameters" << slog::endl;
36 throw std::logic_error("Parameter -ni should be greater than zero (default 1)");
39 if (FLAGS_i.empty()) {
40 throw std::logic_error("Parameter -i is not set");
43 if (FLAGS_m.empty()) {
44 throw std::logic_error("Parameter -m is not set");
51 * @brief The entry point for the Inference Engine sample application
52 * @file classification_sample/main.cpp
53 * @example classification_sample/main.cpp
55 int main(int argc, char *argv[]) {
57 slog::info << "InferenceEngine: " << GetInferenceEngineVersion() << slog::endl;
59 // ------------------------------ Parsing and validation of input args ---------------------------------
60 if (!ParseAndCheckCommandLine(argc, argv)) {
64 /** This vector stores paths to the processed images **/
65 std::vector<std::string> imageNames;
66 parseInputFilesArguments(imageNames);
67 if (imageNames.empty()) throw std::logic_error("No suitable images were found");
68 // -----------------------------------------------------------------------------------------------------
70 // --------------------------- 1. Load Plugin for inference engine -------------------------------------
71 slog::info << "Loading plugin" << slog::endl;
72 InferencePlugin plugin = PluginDispatcher({ FLAGS_pp, "../../../lib/intel64" , "" }).getPluginByDevice(FLAGS_d);
74 static_cast<InferenceEngine::InferenceEnginePluginPtr>(plugin)->SetLogCallback(error_listener);
77 /** Loading default extensions **/
78 if (FLAGS_d.find("CPU") != std::string::npos) {
80 * cpu_extensions library is compiled from "extension" folder containing
81 * custom MKLDNNPlugin layer implementations. These layers are not supported
82 * by mkldnn, but they can be useful for inferring custom topologies.
84 plugin.AddExtension(std::make_shared<Extensions::Cpu::CpuExtensions>());
87 if (!FLAGS_l.empty()) {
88 // CPU(MKLDNN) extensions are loaded as a shared library and passed as a pointer to base extension
89 auto extension_ptr = make_so_pointer<IExtension>(FLAGS_l);
90 plugin.AddExtension(extension_ptr);
91 slog::info << "CPU Extension loaded: " << FLAGS_l << slog::endl;
93 if (!FLAGS_c.empty()) {
94 // clDNN Extensions are loaded from an .xml description and OpenCL kernel files
95 plugin.SetConfig({{PluginConfigParams::KEY_CONFIG_FILE, FLAGS_c}});
96 slog::info << "GPU Extension loaded: " << FLAGS_c << slog::endl;
99 /** Setting plugin parameter for collecting per layer metrics **/
101 plugin.SetConfig({ { PluginConfigParams::KEY_PERF_COUNT, PluginConfigParams::YES } });
104 /** Printing plugin version **/
105 printPluginVersion(plugin, std::cout);
106 // -----------------------------------------------------------------------------------------------------
108 // --------------------------- 2. Read IR Generated by ModelOptimizer (.xml and .bin files) ------------
109 std::string binFileName = fileNameNoExt(FLAGS_m) + ".bin";
110 slog::info << "Loading network files:"
112 "\n\t" << binFileName <<
115 CNNNetReader networkReader;
116 /** Reading network model **/
117 networkReader.ReadNetwork(FLAGS_m);
119 /** Loading weights from the .bin file next to the model **/
120 networkReader.ReadWeights(binFileName);
121 CNNNetwork network = networkReader.getNetwork();
122 // -----------------------------------------------------------------------------------------------------
124 // --------------------------- 3. Configure input & output ---------------------------------------------
126 // --------------------------- Prepare input blobs -----------------------------------------------------
127 slog::info << "Preparing input blobs" << slog::endl;
129 /** Taking information about all topology inputs **/
130 InputsDataMap inputInfo = network.getInputsInfo();
131 if (inputInfo.size() != 1) throw std::logic_error("Sample supports topologies only with 1 input");
133 auto inputInfoItem = *inputInfo.begin();
135 /** Specifying the precision and layout of input data provided by the user.
136 * This should be called before load of the network to the plugin **/
137 inputInfoItem.second->setPrecision(Precision::U8);
138 inputInfoItem.second->setLayout(Layout::NCHW);
140 std::vector<std::shared_ptr<unsigned char>> imagesData;
141 for (auto & i : imageNames) {
142 FormatReader::ReaderPtr reader(i.c_str());
143 if (reader.get() == nullptr) {
144 slog::warn << "Image " + i + " cannot be read!" << slog::endl;
147 /** Store image data **/
148 std::shared_ptr<unsigned char> data(
149 reader->getData(inputInfoItem.second->getTensorDesc().getDims()[3],
150 inputInfoItem.second->getTensorDesc().getDims()[2]));
151 if (data.get() != nullptr) {
152 imagesData.push_back(data);
155 if (imagesData.empty()) throw std::logic_error("Valid input images were not found!");
157 /** Setting batch size using image count **/
158 network.setBatchSize(imagesData.size());
159 size_t batchSize = network.getBatchSize();
160 slog::info << "Batch size is " << std::to_string(batchSize) << slog::endl;
162 // ------------------------------ Prepare output blobs -------------------------------------------------
163 slog::info << "Preparing output blobs" << slog::endl;
165 OutputsDataMap outputInfo(network.getOutputsInfo());
166 // BlobMap outputBlobs;
167 std::string firstOutputName;
169 for (auto & item : outputInfo) {
170 if (firstOutputName.empty()) {
171 firstOutputName = item.first;
173 DataPtr outputData = item.second;
175 throw std::logic_error("output data pointer is not valid");
178 item.second->setPrecision(Precision::FP32);
181 const SizeVector outputDims = outputInfo.begin()->second->getDims();
183 bool outputCorrect = false;
184 if (outputDims.size() == 2 /* NC */) {
185 outputCorrect = true;
186 } else if (outputDims.size() == 4 /* NCHW */) {
188 if (outputDims[2] == 1 && outputDims[3] == 1) outputCorrect = true;
191 if (!outputCorrect) {
192 throw std::logic_error("Incorrect output dimensions for classification model");
194 // -----------------------------------------------------------------------------------------------------
196 // --------------------------- 4. Loading model to the plugin ------------------------------------------
197 slog::info << "Loading model to the plugin" << slog::endl;
199 ExecutableNetwork executable_network = plugin.LoadNetwork(network, {});
200 inputInfoItem.second = {};
204 // -----------------------------------------------------------------------------------------------------
206 // --------------------------- 5. Create infer request -------------------------------------------------
207 InferRequest infer_request = executable_network.CreateInferRequest();
208 // -----------------------------------------------------------------------------------------------------
210 // --------------------------- 6. Prepare input --------------------------------------------------------
211 /** Iterate over all the input blobs **/
212 for (const auto & item : inputInfo) {
213 /** Creating input blob **/
214 Blob::Ptr input = infer_request.GetBlob(item.first);
216 /** Filling input tensor with images. First b channel, then g and r channels **/
217 size_t num_channels = input->getTensorDesc().getDims()[1];
218 size_t image_size = input->getTensorDesc().getDims()[2] * input->getTensorDesc().getDims()[3];
220 auto data = input->buffer().as<PrecisionTrait<Precision::U8>::value_type*>();
222 /** Iterate over all input images **/
223 for (size_t image_id = 0; image_id < imagesData.size(); ++image_id) {
224 /** Iterate over all pixels in image (b,g,r) **/
225 for (size_t pid = 0; pid < image_size; pid++) {
226 /** Iterate over all channels **/
227 for (size_t ch = 0; ch < num_channels; ++ch) {
228 /** [images stride + channels stride + pixel id ] all in bytes **/
229 data[image_id * image_size * num_channels + ch * image_size + pid ] = imagesData.at(image_id).get()[pid*num_channels + ch];
235 // -----------------------------------------------------------------------------------------------------
237 // --------------------------- 7. Do inference ---------------------------------------------------------
238 slog::info << "Starting inference (" << FLAGS_ni << " iterations)" << slog::endl;
240 typedef std::chrono::high_resolution_clock Time;
241 typedef std::chrono::duration<double, std::ratio<1, 1000>> ms;
242 typedef std::chrono::duration<float> fsec;
245 /** Start inference & calc performance **/
246 for (int iter = 0; iter < FLAGS_ni; ++iter) {
247 auto t0 = Time::now();
248 infer_request.Infer();
249 auto t1 = Time::now();
251 ms d = std::chrono::duration_cast<ms>(fs);
254 // -----------------------------------------------------------------------------------------------------
256 // --------------------------- 8. Process output -------------------------------------------------------
257 slog::info << "Processing output blobs" << slog::endl;
259 const Blob::Ptr output_blob = infer_request.GetBlob(firstOutputName);
260 auto output_data = output_blob->buffer().as<PrecisionTrait<Precision::FP32>::value_type*>();
262 /** Validating -nt value **/
263 const int resultsCnt = output_blob->size() / batchSize;
264 if (FLAGS_nt > resultsCnt || FLAGS_nt < 1) {
265 slog::warn << "-nt " << FLAGS_nt << " is not available for this network (-nt should be less than " \
266 << resultsCnt+1 << " and more than 0)\n will be used maximal value : " << resultsCnt;
267 FLAGS_nt = resultsCnt;
270 /** This vector stores ids of top N results **/
271 std::vector<unsigned> results;
272 TopResults(FLAGS_nt, *output_blob, results);
274 std::cout << std::endl << "Top " << FLAGS_nt << " results:" << std::endl << std::endl;
276 /** Read labels from file (e.g. AlexNet.labels) **/
277 bool labelsEnabled = false;
278 std::string labelFileName = fileNameNoExt(FLAGS_m) + ".labels";
279 std::vector<std::string> labels;
281 std::ifstream inputFile;
282 inputFile.open(labelFileName, std::ios::in);
283 if (inputFile.is_open()) {
285 while (std::getline(inputFile, strLine)) {
287 labels.push_back(strLine);
289 labelsEnabled = true;
292 /** Print the result iterating over each batch **/
293 for (int image_id = 0; image_id < batchSize; ++image_id) {
294 std::cout << "Image " << imageNames[image_id] << std::endl << std::endl;
295 for (size_t id = image_id * FLAGS_nt, cnt = 0; cnt < FLAGS_nt; ++cnt, ++id) {
296 std::cout.precision(7);
297 /** Getting probability for resulting class **/
298 const auto result = output_data[results[id] + image_id*(output_blob->size() / batchSize)];
299 std::cout << std::left << std::fixed << results[id] << " " << result;
301 std::cout << " label " << labels[results[id]] << std::endl;
303 std::cout << " label #" << results[id] << std::endl;
306 std::cout << std::endl;
308 // -----------------------------------------------------------------------------------------------------
309 std::cout << std::endl << "total inference time: " << total << std::endl;
310 std::cout << "Average running time of one iteration: " << total / static_cast<double>(FLAGS_ni) << " ms" << std::endl;
311 std::cout << std::endl << "Throughput: " << 1000 * static_cast<double>(FLAGS_ni) * batchSize / total << " FPS" << std::endl;
312 std::cout << std::endl;
314 /** Show performance results **/
316 printPerformanceCounts(infer_request, std::cout);
319 catch (const std::exception& error) {
320 slog::err << "" << error.what() << slog::endl;
324 slog::err << "Unknown/internal exception happened." << slog::endl;
328 slog::info << "Execution successful" << slog::endl;