// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <iomanip>
#include <vector>
#include <memory>
#include <string>
#include <cstdlib>

#include <opencv2/opencv.hpp>
#include <inference_engine.hpp>
#include <samples/classification_results.h>

using namespace InferenceEngine;

int main(int argc, char *argv[]) {
    try {
        // ------------------------------ Parsing and validation of input args ---------------------------------
        if (argc != 4) {
            std::cout << "Usage : ./hello_request_classification <path_to_model> <path_to_image> <device_name>"
                      << std::endl;
            return EXIT_FAILURE;
        }

        const std::string input_model{argv[1]};
        const std::string input_image_path{argv[2]};
        const std::string device_name{argv[3]};
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 1. Load Plugin for inference engine -------------------------------------
        InferencePlugin plugin = PluginDispatcher().getPluginByDevice(device_name);
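        // PluginDispatcher resolves the device name (e.g. "CPU", "GPU") to the matching
        // plugin library found in the default search paths. Note: it can also be handed
        // explicit directories, e.g. PluginDispatcher({"/opt/custom_plugins"}) -- that
        // path is purely illustrative.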
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 2. Read IR Generated by ModelOptimizer (.xml and .bin files) ------------
        CNNNetReader network_reader;
        network_reader.ReadNetwork(input_model);
        network_reader.ReadWeights(input_model.substr(0, input_model.size() - 4) + ".bin");
        network_reader.getNetwork().setBatchSize(1);
        CNNNetwork network = network_reader.getNetwork();
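        // ReadNetwork() parses the .xml topology and ReadWeights() maps the matching
        // .bin file; the substr() above just swaps the ".xml" suffix for ".bin", so the
        // model path is expected to end in ".xml".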
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 3. Configure input & output ---------------------------------------------

        // --------------------------- Prepare input blobs -----------------------------------------------------
        /** Taking information about all topology inputs **/
        InputsDataMap input_info(network.getInputsInfo());
        /** Iterating over all input info **/
        for (auto &item : input_info) {
            InputInfo::Ptr input_data = item.second;
            input_data->setPrecision(Precision::U8);
            input_data->setLayout(Layout::NCHW);
        }
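        // Requesting U8 precision lets image bytes be copied into the blob as-is; the
        // plugin converts them to the network's native precision (typically FP32)
        // internally. NCHW matches the planar copy done in the "Prepare input" step.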

        // ------------------------------ Prepare output blobs -------------------------------------------------
        /** Taking information about all topology outputs **/
        OutputsDataMap output_info(network.getOutputsInfo());
        /** Iterating over all output info **/
        for (auto &item : output_info) {
            DataPtr output_data = item.second;
            if (!output_data) {
                throw std::runtime_error("Output data pointer is invalid");
            }
            output_data->setPrecision(Precision::FP32);
        }
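        // FP32 is the conventional precision for classification scores; the plugin
        // converts the network's raw output to it if the device computes in another format.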
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 4. Loading model to the plugin ------------------------------------------
        ExecutableNetwork executable_network = plugin.LoadNetwork(network, {});
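        // The empty braces are the plugin configuration map. As an illustrative example
        // (assuming the standard config keys), per-layer performance counters could be
        // requested with:
        //   plugin.LoadNetwork(network, {{PluginConfigParams::KEY_PERF_COUNT, PluginConfigParams::YES}});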
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 5. Create infer request -------------------------------------------------
        InferRequest async_infer_request = executable_network.CreateInferRequest();
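        // A single request suffices here; for throughput, several requests can be
        // created from the same ExecutableNetwork and run concurrently.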
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 6. Prepare input --------------------------------------------------------
        for (auto &item : input_info) {
            cv::Mat image = cv::imread(input_image_path);
            if (image.empty()) throw std::logic_error("Invalid image at path: " + input_image_path);

            auto input_name = item.first;
            InputInfo::Ptr input_data = item.second;

            /** Getting input blob **/
            Blob::Ptr input = async_infer_request.GetBlob(input_name);
            auto input_buffer = input->buffer().as<PrecisionTrait<Precision::U8>::value_type *>();

            /* Resize the image to the network's input resolution (W = dims[3], H = dims[2]) */
            cv::resize(image, image, cv::Size(input_data->getTensorDesc().getDims()[3], input_data->getTensorDesc().getDims()[2]));
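            /* Note: the resize runs on the CPU through OpenCV. Assuming the 2019 R1
               preprocessing API, the Inference Engine could do it instead via
               input_data->getPreProcess().setResizeAlgorithm(ResizeAlgorithm::RESIZE_BILINEAR)
               set before the network is loaded. */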
            auto dims = input->getTensorDesc().getDims();
            size_t channels_number = dims[1];
            size_t image_size = dims[3] * dims[2];
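            /* OpenCV stores the image as interleaved HWC bytes (B, G, R per pixel),
               while an NCHW blob expects planar data, so the copy below de-interleaves:
               pixel `pid` of channel `ch` lands at offset ch * image_size + pid. */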
            /** Fill the input tensor with planes: first the B channel, then G, then R **/
            for (size_t pid = 0; pid < image_size; ++pid) {
                for (size_t ch = 0; ch < channels_number; ++ch) {
                    input_buffer[ch * image_size + pid] = image.at<cv::Vec3b>(pid)[ch];
                }
            }
        }
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 7. Do inference ---------------------------------------------------------
        const int max_number_of_iterations = 10;
        int iterations = max_number_of_iterations;
        /** Set callback function to be called on completion of the async request **/
        async_infer_request.SetCompletionCallback(
                [&] {
                    std::cout << "Completed " << max_number_of_iterations - iterations + 1 << " of "
                              << max_number_of_iterations << " async requests" << std::endl;
                    if (--iterations) {
                        /** Restart the async request until the iteration budget is exhausted **/
                        async_infer_request.StartAsync();
                    }
                });
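        // Each completion callback re-queues the request, so the single StartAsync()
        // below kicks off a chain of max_number_of_iterations runs.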
        /** Start the async request for the first time **/
        async_infer_request.StartAsync();
        /** Wait for all repetitions of the async request to finish **/
        for (int i = 0; i < max_number_of_iterations; i++) {
            async_infer_request.Wait(IInferRequest::WaitMode::RESULT_READY);
        }
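        // Wait(RESULT_READY) blocks until the in-flight run finishes; looping it once per
        // iteration gives each callback time to re-queue the next run. For a plain
        // synchronous call, InferRequest::Infer() could replace this whole pattern.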
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 8. Process output -------------------------------------------------------
        for (auto &item : output_info) {
            auto output_name = item.first;
            Blob::Ptr output = async_infer_request.GetBlob(output_name);
            // Print classification results
            ClassificationResult classificationResult(output, {input_image_path});
            classificationResult.print();
        }
        // -----------------------------------------------------------------------------------------------------
    } catch (const std::exception & ex) {
        std::cerr << ex.what() << std::endl;
        return EXIT_FAILURE;
    }
    return EXIT_SUCCESS;
}