inference-engine/samples/hello_classification/main.cpp (dldt, 2019 R1)
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <iomanip>
#include <vector>
#include <memory>
#include <string>
#include <cstdlib>
#include <stdexcept>

#ifdef UNICODE
#include <tchar.h>
#endif

#include <opencv2/opencv.hpp>
#include <inference_engine.hpp>
#include <samples/classification_results.h>

using namespace InferenceEngine;

#ifndef UNICODE
#define tcout std::cout
#define _T(STR) STR
#else
#define tcout std::wcout
#endif
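
// NOTE: file_name_t and fileNameToString() are used below but their definitions are
// not part of this listing. The following is a minimal sketch that makes the sample
// self-contained, assuming a naive wide-to-narrow conversion is acceptable for the
// UNICODE build (it is lossy for non-ASCII paths).
#ifndef UNICODE
typedef std::string file_name_t;

static std::string fileNameToString(const file_name_t &name) {
    return name;
}
#else
typedef std::wstring file_name_t;

static std::string fileNameToString(const file_name_t &name) {
    // Truncate each wchar_t to char; sufficient for ASCII-only paths.
    return std::string(name.begin(), name.end());
}
#endif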

#ifndef UNICODE
int main(int argc, char *argv[]) {
#else
int wmain(int argc, wchar_t *argv[]) {
#endif
    try {
        // ------------------------------ Parsing and validation of input args ---------------------------------
        if (argc != 3) {
            tcout << _T("Usage : ./hello_classification <path_to_model> <path_to_image>") << std::endl;
            return EXIT_FAILURE;
        }

        const file_name_t input_model{argv[1]};
        const file_name_t input_image_path{argv[2]};
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 1. Load Plugin for inference engine -------------------------------------
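        /* PluginDispatcher searches the default plugin directories and loads the
           plugin that serves the requested device (here the CPU). */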
        InferencePlugin plugin(PluginDispatcher().getSuitablePlugin(TargetDevice::eCPU));
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 2. Read IR Generated by ModelOptimizer (.xml and .bin files) ------------
        CNNNetReader network_reader;
        network_reader.ReadNetwork(fileNameToString(input_model));
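        /* Derive the .bin weights path from the .xml path; this assumes the model
           path ends in ".xml". */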
        network_reader.ReadWeights(fileNameToString(input_model).substr(0, input_model.size() - 4) + ".bin");
        network_reader.getNetwork().setBatchSize(1);
        CNNNetwork network = network_reader.getNetwork();
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 3. Configure input & output ---------------------------------------------
        // --------------------------- Prepare input blobs -----------------------------------------------------
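        /* This sample works with networks that have a single input and a single
           output, so it simply takes the first entry of each map. */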
        InputInfo::Ptr input_info = network.getInputsInfo().begin()->second;
        std::string input_name = network.getInputsInfo().begin()->first;

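        /* Request U8 precision and NCHW layout for the input: the image bytes can
           then be copied in as-is, and the plugin converts them to the network's
           native precision during inference. */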
        input_info->setLayout(Layout::NCHW);
        input_info->setPrecision(Precision::U8);

        // --------------------------- Prepare output blobs ----------------------------------------------------
        DataPtr output_info = network.getOutputsInfo().begin()->second;
        std::string output_name = network.getOutputsInfo().begin()->first;

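        /* Read classification scores back as 32-bit floats. */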
        output_info->setPrecision(Precision::FP32);
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 4. Loading model to the plugin ------------------------------------------
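        /* The second argument is a map of plugin configuration options, left empty here. */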
        ExecutableNetwork executable_network = plugin.LoadNetwork(network, {});
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 5. Create infer request -------------------------------------------------
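        /* The infer request gives access to the input and output blobs that are
           filled and read below. */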
        InferRequest infer_request = executable_network.CreateInferRequest();
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 6. Prepare input --------------------------------------------------------

        cv::Mat image = cv::imread(fileNameToString(input_image_path));
        if (image.empty()) {
            throw std::logic_error("Cannot read input image: " + fileNameToString(input_image_path));
        }

        /* Resize manually and copy data from the image to the input blob */
        Blob::Ptr input = infer_request.GetBlob(input_name);
        auto input_data = input->buffer().as<PrecisionTrait<Precision::U8>::value_type *>();

        cv::resize(image, image, cv::Size(input_info->getTensorDesc().getDims()[3], input_info->getTensorDesc().getDims()[2]));

        size_t channels_number = input->getTensorDesc().getDims()[1];
        size_t image_size = input->getTensorDesc().getDims()[3] * input->getTensorDesc().getDims()[2];

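        /* OpenCV keeps pixels interleaved (HWC order, BGR channels) while the blob
           expects planar CHW data, so the copy reorders channel by channel. */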
        for (size_t pid = 0; pid < image_size; ++pid) {
            for (size_t ch = 0; ch < channels_number; ++ch) {
                input_data[ch * image_size + pid] = image.at<cv::Vec3b>(pid)[ch];
            }
        }
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 7. Do inference --------------------------------------------------------
        /* Running the request synchronously */
        infer_request.Infer();
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 8. Process output ------------------------------------------------------
        Blob::Ptr output = infer_request.GetBlob(output_name);
        // Print classification results
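        /* ClassificationResult formats the output blob and prints the
           highest-scoring classes for the image. */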
        ClassificationResult classificationResult(output, {fileNameToString(input_image_path)});
        classificationResult.print();

        // -----------------------------------------------------------------------------------------------------
    } catch (const std::exception & ex) {
        std::cerr << ex.what() << std::endl;
        return EXIT_FAILURE;
    }
    return EXIT_SUCCESS;
}