// Copyright (c) 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
 * @brief The entry point for the Inference Engine sample application
 * @file classification_sample_async/main.cpp
 * @example classification_sample_async/main.cpp
 */
#include <chrono>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <inference_engine.hpp>
#include <format_reader/format_reader_ptr.h>

#include <samples/common.hpp>
#include <samples/slog.hpp>
#include <samples/args_helper.hpp>

#include <ext_list.hpp>

#include "classification_sample_async.h"

using namespace InferenceEngine;
bool ParseAndCheckCommandLine(int argc, char *argv[]) {
    // ---------------------------Parsing and validation of input args--------------------------------------
    slog::info << "Parsing input parameters" << slog::endl;

    gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
    if (FLAGS_h) {
        showUsage();
        return false;
    }

    if (FLAGS_ni < 1) {
        throw std::logic_error("Parameter -ni must be greater than 0 (default: 1)");
    }

    if (FLAGS_nireq < 1) {
        throw std::logic_error("Parameter -nireq must be greater than 0 (default: 1)");
    }

    if (FLAGS_i.empty()) {
        throw std::logic_error("Parameter -i is not set");
    }

    if (FLAGS_m.empty()) {
        throw std::logic_error("Parameter -m is not set");
    }

    if (FLAGS_ni < FLAGS_nireq) {
        throw std::logic_error("Number of iterations (-ni) cannot be less than the number of infer requests (-nireq)");
    }

    return true;
}
int main(int argc, char *argv[]) {
    try {
        slog::info << "InferenceEngine: " << GetInferenceEngineVersion() << slog::endl;

        // ------------------------------ Parsing and validation of input args ---------------------------------
        if (!ParseAndCheckCommandLine(argc, argv)) {
            return 0;
        }

        /** This vector stores paths to the processed images **/
        std::vector<std::string> imageNames;
        parseImagesArguments(imageNames);
        if (imageNames.empty()) throw std::logic_error("No suitable images were found");
        // -----------------------------------------------------------------------------------------------------
        // --------------------------- 1. Load Plugin for inference engine -------------------------------------
        slog::info << "Loading plugin" << slog::endl;
        InferencePlugin plugin = PluginDispatcher({ FLAGS_pp, "../../../lib/intel64", "" }).getPluginByDevice(FLAGS_d);

        /** Loading default extensions **/
        if (FLAGS_d.find("CPU") != std::string::npos) {
            /**
             * cpu_extensions library is compiled from the "extension" folder containing
             * custom MKLDNNPlugin layer implementations. These layers are not supported
             * by mkldnn, but they can be useful for inferring custom topologies.
             **/
            plugin.AddExtension(std::make_shared<Extensions::Cpu::CpuExtensions>());
        }

        if (!FLAGS_l.empty()) {
            // CPU (MKLDNN) extensions are loaded as a shared library and passed as a pointer to the base extension
            IExtensionPtr extension_ptr = make_so_pointer<IExtension>(FLAGS_l);
            plugin.AddExtension(extension_ptr);
            slog::info << "CPU Extension loaded: " << FLAGS_l << slog::endl;
        }
        if (!FLAGS_c.empty()) {
            // clDNN extensions are loaded from an .xml description and OpenCL kernel files
            plugin.SetConfig({{PluginConfigParams::KEY_CONFIG_FILE, FLAGS_c}});
            slog::info << "GPU Extension loaded: " << FLAGS_c << slog::endl;
        }

        /** Printing plugin version **/
        printPluginVersion(plugin, std::cout);
        // -----------------------------------------------------------------------------------------------------
        // --------------------------- 2. Read IR generated by the Model Optimizer (.xml and .bin files) -------
        slog::info << "Loading network files" << slog::endl;

        CNNNetReader networkReader;
        /** Read network model **/
        networkReader.ReadNetwork(FLAGS_m);

        /** Extract model name and load weights **/
        std::string binFileName = fileNameNoExt(FLAGS_m) + ".bin";
        networkReader.ReadWeights(binFileName);

        CNNNetwork network = networkReader.getNetwork();
        // -----------------------------------------------------------------------------------------------------
        // --------------------------- 3. Configure input & output ---------------------------------------------

        // --------------------------- Prepare input blobs -----------------------------------------------------
        slog::info << "Preparing input blobs" << slog::endl;

        /** Taking information about all topology inputs **/
        InputsDataMap inputInfo(network.getInputsInfo());
        if (inputInfo.size() != 1) throw std::logic_error("Sample supports topologies with 1 input only");

        auto inputInfoItem = *inputInfo.begin();

        /** Specifying the precision and layout of input data provided by the user.
         *  This should be called before loading the network to the plugin **/
        inputInfoItem.second->setPrecision(Precision::U8);
        inputInfoItem.second->setLayout(Layout::NCHW);
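        /** Note: with U8 input precision the raw image bytes read below can be copied
         *  into the input blob as-is; the plugin converts them to the network's
         *  operating precision internally **/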
        std::vector<std::shared_ptr<unsigned char>> imagesData;
        for (auto & i : imageNames) {
            FormatReader::ReaderPtr reader(i.c_str());
            if (reader.get() == nullptr) {
                slog::warn << "Image " + i + " cannot be read!" << slog::endl;
                continue;
            }
            /** Store image data, resized to the network input dimensions (width x height) **/
            std::shared_ptr<unsigned char> data(
                    reader->getData(inputInfoItem.second->getTensorDesc().getDims()[3],
                                    inputInfoItem.second->getTensorDesc().getDims()[2]));
            if (data.get() != nullptr) {
                imagesData.push_back(data);
            }
        }
        if (imagesData.empty()) throw std::logic_error("Valid input images were not found!");

        /** Setting batch size using image count **/
        network.setBatchSize(imagesData.size());
        size_t batchSize = network.getBatchSize();
        slog::info << "Batch size is " << std::to_string(batchSize) << slog::endl;
        // ------------------------------ Prepare output blobs -------------------------------------------------
        slog::info << "Preparing output blobs" << slog::endl;

        OutputsDataMap outputInfo(network.getOutputsInfo());
        /** One output blob per infer request, so parallel requests do not overwrite each other's results **/
        std::vector<Blob::Ptr> outputBlobs;
        for (size_t i = 0; i < FLAGS_nireq; i++) {
            auto outputBlob = make_shared_blob<PrecisionTrait<Precision::FP32>::value_type>(outputInfo.begin()->second->getTensorDesc());
            outputBlob->allocate();
            outputBlobs.push_back(outputBlob);
        }
        // -----------------------------------------------------------------------------------------------------
        // --------------------------- 4. Loading model to the plugin ------------------------------------------
        slog::info << "Loading model to the plugin" << slog::endl;

        std::map<std::string, std::string> config;
        if (FLAGS_pc) {
            /** KEY_PERF_COUNT enables the per-layer counters retrieved in step 8 **/
            config[PluginConfigParams::KEY_PERF_COUNT] = PluginConfigParams::YES;
        }
        ExecutableNetwork executable_network = plugin.LoadNetwork(network, config);
        // -----------------------------------------------------------------------------------------------------
        // --------------------------- 5. Create infer request -------------------------------------------------
        std::vector<InferRequest> inferRequests;
        for (size_t i = 0; i < FLAGS_nireq; i++) {
            InferRequest inferRequest = executable_network.CreateInferRequest();
            inferRequests.push_back(inferRequest);
        }
        // -----------------------------------------------------------------------------------------------------
        // --------------------------- 6. Prepare input --------------------------------------------------------
        BlobMap inputBlobs;
        for (auto & item : inputInfo) {
            auto input = make_shared_blob<PrecisionTrait<Precision::U8>::value_type>(item.second->getTensorDesc());
            input->allocate();
            inputBlobs[item.first] = input;

            auto dims = input->getTensorDesc().getDims();
            /** Fill input tensor with images. First b channel, then g and r channels **/
            size_t num_channels = dims[1];
            size_t image_size = dims[3] * dims[2];

            /** Iterate over all input images **/
            for (size_t image_id = 0; image_id < imagesData.size(); ++image_id) {
                /** Iterate over all pixels in the image (b, g, r) **/
                for (size_t pid = 0; pid < image_size; pid++) {
                    /** Iterate over all channels **/
                    for (size_t ch = 0; ch < num_channels; ++ch) {
                        /** [image stride + channel stride + pixel id] all in bytes **/
                        input->data()[image_id * image_size * num_channels + ch * image_size + pid] =
                                imagesData.at(image_id).get()[pid * num_channels + ch];
                    }
                }
            }
        }
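        /** The loop above repacks the interleaved data returned by FormatReader
         *  (b g r b g r ...) into the planar layout of the NCHW input blob:
         *  all b values first, then all g values, then all r values, per image **/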
        for (size_t i = 0; i < FLAGS_nireq; i++) {
            inferRequests[i].SetBlob(inputBlobs.begin()->first, inputBlobs.begin()->second);
            inferRequests[i].SetBlob(outputInfo.begin()->first, outputBlobs[i]);
        }
        // -----------------------------------------------------------------------------------------------------
        // --------------------------- 7. Do inference ---------------------------------------------------------
        slog::info << "Start inference (" << FLAGS_ni << " iterations)" << slog::endl;

        typedef std::chrono::high_resolution_clock Time;
        typedef std::chrono::duration<double, std::ratio<1, 1000>> ms;
        typedef std::chrono::duration<float> fsec;

        double total = 0.0;

        size_t currentInfer = 0;
        size_t prevInfer = (FLAGS_nireq > 1) ? 1 : 0;

        /** Warming up the first request - kept out of the measured scope **/
        inferRequests[0].StartAsync();
        inferRequests[0].Wait(10000);

        /** Start inference & calc performance **/
        auto t0 = Time::now();
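        /** The loop below pipelines the requests: each iteration starts one request
         *  asynchronously while waiting on a previously started one, keeping up to
         *  -nireq requests in flight; the extra FLAGS_nireq iterations only wait,
         *  draining the requests still running after the last StartAsync **/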
        for (int iter = 0; iter < FLAGS_ni + FLAGS_nireq; ++iter) {
            if (iter < FLAGS_ni) {
                inferRequests[currentInfer].StartAsync();
            }
            inferRequests[prevInfer].Wait(10000);

            currentInfer++;
            if (currentInfer >= FLAGS_nireq) {
                currentInfer = 0;
            }
            prevInfer++;
            if (prevInfer >= FLAGS_nireq) {
                prevInfer = 0;
            }
        }
        auto t1 = Time::now();
        fsec fs = t1 - t0;
        ms d = std::chrono::duration_cast<ms>(fs);
        total = d.count();
        // -----------------------------------------------------------------------------------------------------
        // --------------------------- 8. Process output -------------------------------------------------------
        slog::info << "Processing output blobs" << slog::endl;

        for (size_t i = 0; i < FLAGS_nireq; i++) {
            /** Validating -nt value **/
            const int resultsCnt = outputBlobs[i]->size() / batchSize;
            if (FLAGS_nt > resultsCnt || FLAGS_nt < 1) {
                slog::warn << "-nt " << FLAGS_nt << " is not available for this network (-nt should be in the range [1, "
                           << resultsCnt << "]); the maximal value " << resultsCnt << " will be used" << slog::endl;
                FLAGS_nt = resultsCnt;
            }

            /** This vector stores ids of the top N results **/
            std::vector<unsigned> results;
            TopResults(FLAGS_nt, *outputBlobs[i], results);
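            /** TopResults (a helper from the samples' common headers) fills `results`
             *  with the indices of the FLAGS_nt highest-scoring classes for each image
             *  in the batch **/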
            std::cout << std::endl << "Top " << FLAGS_nt << " results:" << std::endl << std::endl;

            /** Read labels from a file (e.g. AlexNet.labels) **/
            bool labelsEnabled = false;
            std::string labelFileName = fileNameNoExt(FLAGS_m) + ".labels";
            std::vector<std::string> labels;

            std::ifstream inputFile;
            inputFile.open(labelFileName, std::ios::in);
            if (inputFile.is_open()) {
                std::string strLine;
                while (std::getline(inputFile, strLine)) {
                    trim(strLine);
                    labels.push_back(strLine);
                }
                labelsEnabled = true;
            }

            /** Print the result iterating over each batch **/
            for (size_t image_id = 0; image_id < batchSize; ++image_id) {
                std::cout << "Image " << imageNames[image_id] << std::endl << std::endl;
                for (size_t id = image_id * FLAGS_nt, cnt = 0; cnt < FLAGS_nt; ++cnt, ++id) {
                    std::cout.precision(7);
                    /** Getting probability for the resulting class **/
                    auto result = outputBlobs[i]->buffer().
                            as<PrecisionTrait<Precision::FP32>::value_type*>()[results[id] + image_id * (outputBlobs[i]->size() / batchSize)];
                    std::cout << std::left << std::fixed << results[id] << " " << result;
                    if (labelsEnabled) {
                        std::cout << " label " << labels[results[id]] << std::endl;
                    } else {
                        std::cout << " label #" << results[id] << std::endl;
                    }
                }
                std::cout << std::endl;
            }
        }
        // -----------------------------------------------------------------------------------------------------
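        /** total holds the elapsed time of the measured loop in milliseconds; each of
         *  the FLAGS_ni iterations infers a full batch of batchSize images, hence
         *  FPS = 1000 * FLAGS_ni * batchSize / total **/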
        std::cout << std::endl << "total inference time: " << total << " ms" << std::endl;
        std::cout << std::endl << "Throughput: " << 1000 * static_cast<double>(FLAGS_ni) * batchSize / total << " FPS" << std::endl;
        std::cout << std::endl;

        /** Show performance results **/
        std::map<std::string, InferenceEngineProfileInfo> performanceMap;
        if (FLAGS_pc) {
            for (size_t nireq = 0; nireq < FLAGS_nireq; nireq++) {
                performanceMap = inferRequests[nireq].GetPerformanceCounts();
                printPerformanceCounts(performanceMap, std::cout);
            }
        }
    }
    catch (const std::exception& error) {
        slog::err << error.what() << slog::endl;
        return 1;
    }
    catch (...) {
        slog::err << "Unknown/internal exception happened." << slog::endl;
        return 1;
    }

    slog::info << "Execution successful" << slog::endl;
    return 0;
}