inference-engine/samples/hello_autoresize_classification/main.cpp
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <chrono>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <vector>
#include <memory>
#include <string>
#include <cstdlib>

#include <inference_engine.hpp>
#include <samples/ocv_common.hpp>
#include <samples/classification_results.h>

using namespace InferenceEngine;

int main(int argc, char *argv[]) {
    try {
        // ------------------------------ Parsing and validation of input args ---------------------------------
        if (argc != 4) {
            std::cout << "Usage: ./hello_autoresize_classification <path_to_model> <path_to_image> <device_name>"
                      << std::endl;
            return EXIT_FAILURE;
        }

        const std::string input_model{argv[1]};
        const std::string input_image_path{argv[2]};
        const std::string device_name{argv[3]};
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 1. Load Plugin for inference engine -------------------------------------
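        /* PluginDispatcher searches the default plugin directories for the shared
         * library registered for the requested device name and loads it; an unknown
         * device name makes this call throw. */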
        InferencePlugin plugin = PluginDispatcher().getPluginByDevice(device_name);
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 2. Read IR Generated by ModelOptimizer (.xml and .bin files) ------------
        size_t batchSize = 1;
        CNNNetReader network_reader;
        network_reader.ReadNetwork(input_model);
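        /* The weights path is derived by replacing the ".xml" extension with ".bin",
         * so the weights file is expected to sit next to the model file. */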
        network_reader.ReadWeights(input_model.substr(0, input_model.size() - 4) + ".bin");
        network_reader.getNetwork().setBatchSize(batchSize);
        CNNNetwork network = network_reader.getNetwork();
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 3. Configure input & output ---------------------------------------------
        // --------------------------- Prepare input blobs -----------------------------------------------------
        InputInfo::Ptr input_info = network.getInputsInfo().begin()->second;
        std::string input_name = network.getInputsInfo().begin()->first;

        /* Mark input as resizable by setting a resize algorithm.
         * This way we can set an input blob of any shape to the infer request;
         * resize and layout conversions are executed automatically during inference. */
        input_info->getPreProcess().setResizeAlgorithm(RESIZE_BILINEAR);
        input_info->setLayout(Layout::NHWC);
        input_info->setPrecision(Precision::U8);
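        /* U8 precision and NHWC layout match the cv::Mat produced by OpenCV, so the
         * image can be attached below without copying; conversion to the network's
         * own precision and layout happens inside Infer(). */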

        // --------------------------- Prepare output blobs ----------------------------------------------------
        DataPtr output_info = network.getOutputsInfo().begin()->second;
        std::string output_name = network.getOutputsInfo().begin()->first;

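        // FP32 keeps the class scores in a form that is straightforward to read back on the host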
        output_info->setPrecision(Precision::FP32);
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 4. Loading model to the plugin ------------------------------------------
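        /* The second argument of LoadNetwork() is a map of plugin configuration keys,
         * left empty here. For instance, per-layer performance counters could likely be
         * enabled with {{PluginConfigParams::KEY_PERF_COUNT, PluginConfigParams::YES}}
         * (an optional tweak, not needed by this sample). */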
        ExecutableNetwork executable_network = plugin.LoadNetwork(network, {});
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 5. Create infer request -------------------------------------------------
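        /* One request is enough for this synchronous sample; several requests driven
         * with StartAsync()/Wait() would let inference overlap with input preparation. */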
        InferRequest infer_request = executable_network.CreateInferRequest();
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 6. Prepare input --------------------------------------------------------
        /* Read input image to a blob and set it to an infer request without resize and layout conversions. */
        cv::Mat image = cv::imread(input_image_path);
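        if (image.empty()) {
            // cv::imread returns an empty Mat on failure instead of throwing, so check here
            throw std::logic_error("Cannot read the input image: " + input_image_path);
        }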
        Blob::Ptr imgBlob = wrapMat2Blob(image);  // just wrap Mat data by Blob::Ptr without allocating new memory
        infer_request.SetBlob(input_name, imgBlob);  // infer_request accepts an input blob of any size
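        // The blob references the cv::Mat pixels rather than copying them, so `image`
        // must stay alive until inference completes.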
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 7. Do inference ---------------------------------------------------------
        typedef std::chrono::high_resolution_clock Time;
        typedef std::chrono::duration<double, std::ratio<1, 1000>> ms;

        double total = 0.0;

        /* Running the request synchronously */
        auto t0 = Time::now();
        infer_request.Infer();  // input pre-processing (resize and layout conversion) is invoked at this step
        auto t1 = Time::now();
        ms d = std::chrono::duration_cast<ms>(t1 - t0);
        total += d.count();
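        /* Note: a single timed run includes one-time warm-up costs; averaging several
         * Infer() calls would give a steadier latency figure. */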
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 8. Process output -------------------------------------------------------
        Blob::Ptr output = infer_request.GetBlob(output_name);
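        /* For a classification model the output blob holds [batch, num_classes] FP32
         * scores; the raw values could also be read directly, e.g. via
         * output->buffer().as<float *>(). */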
        // Print classification results
        ClassificationResult classificationResult(output, {input_image_path});
        classificationResult.print();
        // -----------------------------------------------------------------------------------------------------

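        // `total` is in milliseconds, so 1000 * batchSize / total yields images per second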
        std::cout << std::endl << "total inference time: " << total << " ms" << std::endl;
        std::cout << std::endl << "Throughput: " << 1000 * batchSize / total << " FPS" << std::endl;
        std::cout << std::endl;
    } catch (const std::exception & ex) {
        std::cerr << ex.what() << std::endl;
        return EXIT_FAILURE;
    }
    return EXIT_SUCCESS;
}