1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
6 * @brief The entry point for the Inference Engine style transfer sample application
7 * @file style_transfer_sample/main.cpp
8 * @example style_transfer_sample/main.cpp
17 #include <format_reader_ptr.h>
18 #include <inference_engine.hpp>
19 #include <ext_list.hpp>
21 #include <samples/common.hpp>
22 #include <samples/slog.hpp>
23 #include <samples/args_helper.hpp>
25 #include "style_transfer_sample.h"
27 using namespace InferenceEngine;
// Parses the command line with gflags and validates the flags this sample requires.
//
// @param argc  argument count; passed to gflags by address
// @param argv  argument vector; passed to gflags by address
// @return bool. NOTE(review): the return statements are not visible in this excerpt;
//         main() branches on a false result, presumably to exit early - confirm.
// @throws std::logic_error when -i or -m is empty, and with the "-ni should be more than 0"
//         message (the condition guarding that throw is not visible here).
29 bool ParseAndCheckCommandLine(int argc, char *argv[]) {
30 // ---------------------------Parsing and validation of input args--------------------------------------
31 slog::info << "Parsing input parameters" << slog::endl;
// The trailing `true` is gflags' remove_flags argument: recognized flags are stripped from argv.
33 gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
// NOTE(review): the check that guards this throw is missing from the excerpt; judging by the
// message it rejects FLAGS_ni values below 1 - confirm against the full file.
40 throw std::logic_error("Parameter -ni should be more than 0 !!! (default 1)");
// An input path (-i) is mandatory.
43 if (FLAGS_i.empty()) {
44 throw std::logic_error("Parameter -i is not set");
// A model path (-m) is mandatory.
47 if (FLAGS_m.empty()) {
48 throw std::logic_error("Parameter -m is not set");
// Entry point of the sample. Visible flow:
//   parse/validate arguments -> collect image paths -> load the plugin and optional extensions ->
//   read the IR (.xml from FLAGS_m plus the same-named .bin) -> set FP32 input/output precision ->
//   load the network, create an infer request, fill the input blob -> run FLAGS_ni inferences ->
//   convert the first output blob to 8-bit pixels and write out<N>.bmp files.
// Failures are reported through the catch handlers near the end of the function.
// NOTE(review): this excerpt omits several original lines (closing braces, the try opener that
// matches the catch handler below, return statements, some declarations). The comments added
// here describe only what is visible.
54 int main(int argc, char *argv[]) {
56 slog::info << "InferenceEngine: " << GetInferenceEngineVersion() << slog::endl;
57 // ------------------------------ Parsing and validation of input args ---------------------------------
// NOTE(review): the body of this branch is not visible; presumably the program exits here - confirm.
58 if (!ParseAndCheckCommandLine(argc, argv)) {
62 /** This vector stores paths to the processed images **/
63 std::vector<std::string> imageNames;
64 parseInputFilesArguments(imageNames);
65 if (imageNames.empty()) throw std::logic_error("No suitable images were found");
66 // -----------------------------------------------------------------------------------------------------
68 // --------------------------- 1. Load Plugin for inference engine -------------------------------------
69 slog::info << "Loading plugin" << slog::endl;
// The plugin is looked up by device name (FLAGS_d); FLAGS_pp is handed to the dispatcher.
70 InferencePlugin plugin = PluginDispatcher({FLAGS_pp}).getPluginByDevice(FLAGS_d);
72 /** Printing plugin version **/
73 printPluginVersion(plugin, std::cout);
75 /** Loading default extensions **/
76 if (FLAGS_d.find("CPU") != std::string::npos) {
78 * cpu_extensions library is compiled from "extension" folder containing
79 * custom MKLDNNPlugin layer implementations. These layers are not supported
80 * by mkldnn, but they can be useful for inferring custom topologies.
82 plugin.AddExtension(std::make_shared<Extensions::Cpu::CpuExtensions>());
// Optional user-supplied extension library (-l).
85 if (!FLAGS_l.empty()) {
86 // CPU(MKLDNN) extensions are loaded as a shared library and passed as a pointer to base extension
87 IExtensionPtr extension_ptr = make_so_pointer<IExtension>(FLAGS_l);
88 plugin.AddExtension(extension_ptr);
89 slog::info << "CPU Extension loaded: " << FLAGS_l << slog::endl;
// Optional configuration file (-c), passed to the plugin as KEY_CONFIG_FILE.
91 if (!FLAGS_c.empty()) {
92 // clDNN Extensions are loaded from an .xml description and OpenCL kernel files
93 plugin.SetConfig({{PluginConfigParams::KEY_CONFIG_FILE, FLAGS_c}});
94 slog::info << "GPU Extension loaded: " << FLAGS_c << slog::endl;
96 // -----------------------------------------------------------------------------------------------------
98 // --------------------------- 2. Read IR Generated by ModelOptimizer (.xml and .bin files) ------------
99 slog::info << "Loading network files" << slog::endl;
101 CNNNetReader networkReader;
102 /** Read network model **/
103 networkReader.ReadNetwork(FLAGS_m);
105 /** Extract model name and load weights **/
// The weights path is the model path with its extension replaced by ".bin".
106 std::string binFileName = fileNameNoExt(FLAGS_m) + ".bin";
107 networkReader.ReadWeights(binFileName);
108 CNNNetwork network = networkReader.getNetwork();
109 // -----------------------------------------------------------------------------------------------------
111 // --------------------------- 3. Configure input & output ---------------------------------------------
113 // --------------------------- Prepare input blobs -----------------------------------------------------
114 slog::info << "Preparing input blobs" << slog::endl;
116 /** Taking information about all topology inputs **/
117 InputsDataMap inputInfo(network.getInputsInfo());
// Only single-input topologies are accepted; the sole entry is used below.
119 if (inputInfo.size() != 1) throw std::logic_error("Sample supports topologies only with 1 input");
120 auto inputInfoItem = *inputInfo.begin();
122 /** Iterate over all the input blobs **/
// Holds the pixel buffer of every image that was read successfully.
123 std::vector<std::shared_ptr<uint8_t>> imagesData;
125 /** Specifying the precision of input data.
126 * This should be called before load of the network to the plugin **/
127 inputInfoItem.second->setPrecision(Precision::FP32);
129 /** Collect images data ptrs **/
130 for (auto & i : imageNames) {
131 FormatReader::ReaderPtr reader(i.c_str());
// NOTE(review): the rest of this branch is not visible; presumably the unreadable image is
// skipped after the warning - confirm.
132 if (reader.get() == nullptr) {
133 slog::warn << "Image " + i + " cannot be read!" << slog::endl;
136 /** Store image data **/
// The image data is requested with input dims [3] and [2] as the two size arguments.
// NOTE(review): presumably width and height of an NCHW input - confirm.
137 std::shared_ptr<unsigned char> data(reader->getData(inputInfoItem.second->getTensorDesc().getDims()[3],
138 inputInfoItem.second->getTensorDesc().getDims()[2]));
// Only non-null buffers are kept.
139 if (data.get() != nullptr) {
140 imagesData.push_back(data);
143 if (imagesData.empty()) throw std::logic_error("Valid input images were not found!");
145 /** Setting batch size using image count **/
146 network.setBatchSize(imagesData.size());
147 slog::info << "Batch size is " << std::to_string(network.getBatchSize()) << slog::endl;
149 // ------------------------------ Prepare output blobs -------------------------------------------------
150 slog::info << "Preparing output blobs" << slog::endl;
152 OutputsDataMap outputInfo(network.getOutputsInfo());
153 // BlobMap outputBlobs;
// Name of the first output encountered; only this output is post-processed in step 8.
154 std::string firstOutputName;
// Per-channel mean values taken from the command-line flags. They are subtracted from the input
// pixels and added back to the output.
// NOTE(review): the array has exactly three entries but is indexed by ch < num_channels in the
// input loop below, so that loop relies on the input having at most three channels.
156 const float meanValues[] = {static_cast<const float>(FLAGS_mean_val_r),
157 static_cast<const float>(FLAGS_mean_val_g),
158 static_cast<const float>(FLAGS_mean_val_b)};
160 for (auto & item : outputInfo) {
161 if (firstOutputName.empty()) {
162 firstOutputName = item.first;
164 DataPtr outputData = item.second;
// NOTE(review): the condition guarding this throw is not visible; presumably a null check on
// outputData - confirm.
166 throw std::logic_error("output data pointer is not valid");
169 item.second->setPrecision(Precision::FP32);
171 // -----------------------------------------------------------------------------------------------------
173 // --------------------------- 4. Loading model to the plugin ------------------------------------------
174 slog::info << "Loading model to the plugin" << slog::endl;
175 ExecutableNetwork executable_network = plugin.LoadNetwork(network, {});
176 // -----------------------------------------------------------------------------------------------------
178 // --------------------------- 5. Create infer request -------------------------------------------------
179 InferRequest infer_request = executable_network.CreateInferRequest();
180 // -----------------------------------------------------------------------------------------------------
182 // --------------------------- 6. Prepare input --------------------------------------------------------
183 /** Iterate over all the input blobs **/
184 for (const auto & item : inputInfo) {
185 Blob::Ptr input = infer_request.GetBlob(item.first);
186 /** Filling input tensor with images. First b channel, then g and r channels **/
187 size_t num_channels = input->getTensorDesc().getDims()[1];
188 size_t image_size = input->getTensorDesc().getDims()[3] * input->getTensorDesc().getDims()[2];
// The blob memory is accessed as FP32 elements.
190 auto data = input->buffer().as<PrecisionTrait<Precision::FP32>::value_type*>();
192 /** Iterate over all input images **/
193 for (size_t image_id = 0; image_id < imagesData.size(); ++image_id) {
194 /** Iterate over all pixel in image (b,g,r) **/
195 for (size_t pid = 0; pid < image_size; pid++) {
196 /** Iterate over all channels **/
197 for (size_t ch = 0; ch < num_channels; ++ch) {
198 /** [images stride + channels stride + pixel id ] all in bytes **/
// NOTE(review): the destination index is an element offset into the FP32 buffer, not a byte
// offset. The source is read interleaved (pid * num_channels + ch), the destination is written
// one channel plane after another, and the channel's mean value is subtracted.
199 data[image_id * image_size * num_channels + ch * image_size + pid ] =
200 imagesData.at(image_id).get()[pid*num_channels + ch] - meanValues[ch];
205 // -----------------------------------------------------------------------------------------------------
207 // --------------------------- 7. Do inference ---------------------------------------------------------
208 slog::info << "Start inference (" << FLAGS_ni << " iterations)" << slog::endl;
210 typedef std::chrono::high_resolution_clock Time;
211 typedef std::chrono::duration<double, std::ratio<1, 1000>> ms;
212 typedef std::chrono::duration<float> fsec;
215 /** Start inference & calc performance **/
// Each iteration runs one synchronous Infer() call bracketed by two clock readings.
216 for (size_t iter = 0; iter < FLAGS_ni; ++iter) {
217 auto t0 = Time::now();
218 infer_request.Infer();
219 auto t1 = Time::now();
// NOTE(review): `fs` is not declared in the visible lines; presumably the fsec difference
// t1 - t0 - confirm. The per-iteration value is converted to milliseconds here.
221 ms d = std::chrono::duration_cast<ms>(fs);
225 /** Show performance results **/
// NOTE(review): `total` is declared and accumulated in lines not visible here; presumably the
// sum of the per-iteration millisecond counts - confirm.
226 std::cout << std::endl << "Average running time of one iteration: " << total / static_cast<double>(FLAGS_ni)
227 << " ms" << std::endl;
// NOTE(review): in the full file this call may sit behind a flag check that is not visible here.
230 printPerformanceCounts(infer_request, std::cout);
232 // -----------------------------------------------------------------------------------------------------
234 // --------------------------- 8. Process output -------------------------------------------------------
235 const Blob::Ptr output_blob = infer_request.GetBlob(firstOutputName);
236 const auto output_data = output_blob->buffer().as<PrecisionTrait<Precision::FP32>::value_type*>();
// Output dims are read in the order logged below as [N,C,H,W].
238 size_t num_images = output_blob->getTensorDesc().getDims()[0];
239 size_t num_channels = output_blob->getTensorDesc().getDims()[1];
240 size_t H = output_blob->getTensorDesc().getDims()[2];
241 size_t W = output_blob->getTensorDesc().getDims()[3];
242 size_t nPixels = W * H;
244 slog::info << "Output size [N,C,H,W]: " << num_images << ", " << num_channels << ", " << H << ", " << W << slog::endl;
// Scratch buffer for one interleaved image; it is reused for every image in the batch.
247 std::vector<float> data_img(nPixels * num_channels);
249 for (size_t n = 0; n < num_images; n++) {
250 for (size_t i = 0; i < nPixels; i++) {
// Gather the three channel planes of pixel i (offsets 0, nPixels, 2 * nPixels within image n)
// into interleaved order, adding a mean value back to each plane.
// NOTE(review): exactly three channels are hard-coded here, so this relies on num_channels >= 3.
// The continuation of the first assignment is not visible in this excerpt.
251 data_img[i * num_channels] = static_cast<float>(output_data[i + n * nPixels * num_channels] +
253 data_img[i * num_channels + 1] = static_cast<float>(
254 output_data[(i + nPixels) + n * nPixels * num_channels] + meanValues[1]);
255 data_img[i * num_channels + 2] = static_cast<float>(
256 output_data[(i + 2 * nPixels) + n * nPixels * num_channels] + meanValues[2]);
// Swap the first and third channel of the pixel.
258 float temp = data_img[i * num_channels];
259 data_img[i * num_channels] = data_img[i * num_channels + 2];
260 data_img[i * num_channels + 2] = temp;
// Clamp every channel to [0, 255] before the later conversion to unsigned char.
262 if (data_img[i * num_channels] < 0) data_img[i * num_channels] = 0;
263 if (data_img[i * num_channels] > 255) data_img[i * num_channels] = 255;
265 if (data_img[i * num_channels + 1] < 0) data_img[i * num_channels + 1] = 0;
266 if (data_img[i * num_channels + 1] > 255) data_img[i * num_channels + 1] = 255;
268 if (data_img[i * num_channels + 2] < 0) data_img[i * num_channels + 2] = 0;
269 if (data_img[i * num_channels + 2] > 255) data_img[i * num_channels + 2] = 255;
// Output files are named out1.bmp, out2.bmp, ... (1-based image index).
271 std::string out_img_name = std::string("out" + std::to_string(n + 1) + ".bmp");
272 std::ofstream outFile;
273 outFile.open(out_img_name.c_str(), std::ios_base::binary);
274 if (!outFile.is_open()) {
// NOTE(review): `throw new` throws a std::runtime_error* (a pointer). The
// `catch (const std::exception &)` handler below does not match a pointer, so this error would
// not be reported through error.what(), and the heap allocation is never freed.
// Throwing by value (`throw std::runtime_error(...)`) is probably what was intended.
275 throw new std::runtime_error("Cannot create " + out_img_name);
// Narrow the clamped float pixels to 8-bit values for the BMP writer.
277 std::vector<unsigned char> data_img2;
278 for (float i : data_img) {
279 data_img2.push_back(static_cast<unsigned char>(i));
281 writeOutputBmp(data_img2.data(), H, W, outFile);
283 slog::info << "Image " << out_img_name << " created!" << slog::endl;
286 // -----------------------------------------------------------------------------------------------------
// Exceptions derived from std::exception are logged with their message.
288 catch (const std::exception &error) {
289 slog::err << error.what() << slog::endl;
// NOTE(review): the handler that owns the next message is not visible; presumably a
// catch-all (`catch (...)`) - confirm.
293 slog::err << "Unknown/internal exception happened" << slog::endl;
297 slog::info << "Execution successful" << slog::endl;