// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>

#include <opencv2/opencv.hpp>
#include <inference_engine.hpp>
#include <samples/classification_results.h>

#ifdef UNICODE
#include <tchar.h>
#endif

using namespace InferenceEngine;

#ifndef UNICODE
#define tcout std::cout
#else
#define tcout std::wcout
#endif
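
/* The following definitions are elided from this excerpt; a minimal sketch is
 * given here so the sample is self-contained. file_name_t, _T and
 * fileNameToString are assumed to follow the same UNICODE switch as tcout,
 * and the narrowing conversion assumes ASCII-only paths. */
#ifndef UNICODE
#define file_name_t std::string
#define _T(STR) STR
#else
#define file_name_t std::wstring
#endif

static std::string fileNameToString(const file_name_t &name) {
#ifdef UNICODE
    return std::string(name.begin(), name.end());  // naive narrowing, ASCII paths assumed
#else
    return name;
#endif
}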

#ifndef UNICODE
int main(int argc, char *argv[]) {
#else
int wmain(int argc, wchar_t *argv[]) {
#endif
    try {
        // ------------------------------ Parsing and validation of input args ---------------------------------
        if (argc != 3) {
            tcout << _T("Usage : ./hello_classification <path_to_model> <path_to_image>") << std::endl;
            return EXIT_FAILURE;
        }

        const file_name_t input_model{argv[1]};
        const file_name_t input_image_path{argv[2]};
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 1. Load Plugin for inference engine -------------------------------------
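        /* PluginDispatcher looks up and loads the device plugin for the requested
         * target; here the CPU plugin is used. */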
        InferencePlugin plugin(PluginDispatcher().getSuitablePlugin(TargetDevice::eCPU));
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 2. Read IR Generated by ModelOptimizer (.xml and .bin files) ------------
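        /* ReadNetwork parses the .xml topology; the weights are loaded from the
         * matching .bin file, whose path is derived by swapping the ".xml" extension.
         * The batch size is fixed to 1 for this single-image sample. */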
        CNNNetReader network_reader;
        network_reader.ReadNetwork(fileNameToString(input_model));
        network_reader.ReadWeights(fileNameToString(input_model).substr(0, input_model.size() - 4) + ".bin");
        network_reader.getNetwork().setBatchSize(1);
        CNNNetwork network = network_reader.getNetwork();
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 3. Configure input & output ---------------------------------------------
        // --------------------------- Prepare input blobs -----------------------------------------------------
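        /* The input is requested in NCHW layout with U8 precision so the image
         * bytes can be copied into the blob without an extra conversion step. */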
        InputInfo::Ptr input_info = network.getInputsInfo().begin()->second;
        std::string input_name = network.getInputsInfo().begin()->first;

        input_info->setLayout(Layout::NCHW);
        input_info->setPrecision(Precision::U8);

        // --------------------------- Prepare output blobs ----------------------------------------------------
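        /* The output stays in FP32, i.e. one floating-point score per class. */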
        DataPtr output_info = network.getOutputsInfo().begin()->second;
        std::string output_name = network.getOutputsInfo().begin()->first;

        output_info->setPrecision(Precision::FP32);
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 4. Loading model to the plugin ------------------------------------------
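        /* LoadNetwork compiles the network for the target device; the empty map is
         * the place for optional plugin configuration keys. */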
        ExecutableNetwork executable_network = plugin.LoadNetwork(network, {});
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 5. Create infer request -------------------------------------------------
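        /* The infer request owns the input and output blobs and is the object on
         * which inference is actually executed. */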
        InferRequest infer_request = executable_network.CreateInferRequest();
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 6. Prepare input --------------------------------------------------------
        cv::Mat image = cv::imread(fileNameToString(input_image_path));
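        /* Not part of the original sample: failing fast on an unreadable image
         * avoids a crash in cv::resize below (assumption: throwing here is an
         * acceptable way to bail out). */
        if (image.empty()) {
            throw std::logic_error("Cannot read input image: " + fileNameToString(input_image_path));
        }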

        /* Resize manually and copy data from the image to the input blob */
        Blob::Ptr input = infer_request.GetBlob(input_name);
        auto input_data = input->buffer().as<PrecisionTrait<Precision::U8>::value_type *>();

        cv::resize(image, image, cv::Size(input_info->getTensorDesc().getDims()[3], input_info->getTensorDesc().getDims()[2]));

        size_t channels_number = input->getTensorDesc().getDims()[1];
        size_t image_size = input->getTensorDesc().getDims()[3] * input->getTensorDesc().getDims()[2];
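        /* OpenCV stores pixels interleaved (HWC, BGR); the loop below rearranges
         * them into the planar NCHW layout the blob expects. */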
        for (size_t pid = 0; pid < image_size; ++pid) {
            for (size_t ch = 0; ch < channels_number; ++ch) {
                input_data[ch * image_size + pid] = image.at<cv::Vec3b>(pid)[ch];
            }
        }
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 7. Do inference --------------------------------------------------------
        /* Running the request synchronously */
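        /* Infer() blocks until the result is ready; StartAsync()/Wait() is the
         * non-blocking alternative offered by the same request object. */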
        infer_request.Infer();
        // -----------------------------------------------------------------------------------------------------

        // --------------------------- 8. Process output ------------------------------------------------------
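        /* The output blob holds one score per class; the ClassificationResult
         * helper from the samples sorts them and prints the top entries for the image. */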
        Blob::Ptr output = infer_request.GetBlob(output_name);

        // Print classification results
        ClassificationResult classificationResult(output, {fileNameToString(input_image_path)});
        classificationResult.print();
        // -----------------------------------------------------------------------------------------------------
    } catch (const std::exception & ex) {
        std::cerr << ex.what() << std::endl;
        return EXIT_FAILURE;
    }
    return EXIT_SUCCESS;
}