Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / samples / style_transfer_sample / main.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 /**
6  * @brief The entry point for inference engine style transfer sample application
7  * @file style_transfer_sample/main.cpp
8  * @example style_transfer_sample/main.cpp
9  */
10 #include <fstream>
11 #include <iomanip>
12 #include <vector>
13 #include <string>
14 #include <chrono>
15 #include <memory>
16
17 #include <format_reader_ptr.h>
18 #include <inference_engine.hpp>
19 #include <ext_list.hpp>
20
21 #include <samples/common.hpp>
22 #include <samples/slog.hpp>
23 #include <samples/args_helper.hpp>
24
25 #include "style_transfer_sample.h"
26
27 using namespace InferenceEngine;
28
29 bool ParseAndCheckCommandLine(int argc, char *argv[]) {
30     // ---------------------------Parsing and validation of input args--------------------------------------
31     slog::info << "Parsing input parameters" << slog::endl;
32
33     gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
34     if (FLAGS_h) {
35         showUsage();
36         return false;
37     }
38
39     if (FLAGS_ni < 1) {
40         throw std::logic_error("Parameter -ni should be more than 0 !!! (default 1)");
41     }
42
43     if (FLAGS_i.empty()) {
44         throw std::logic_error("Parameter -i is not set");
45     }
46
47     if (FLAGS_m.empty()) {
48         throw std::logic_error("Parameter -m is not set");
49     }
50
51     return true;
52 }
53
54 int main(int argc, char *argv[]) {
55     try {
56         slog::info << "InferenceEngine: " << GetInferenceEngineVersion() << slog::endl;
57         // ------------------------------ Parsing and validation of input args ---------------------------------
58         if (!ParseAndCheckCommandLine(argc, argv)) {
59             return 0;
60         }
61
62         /** This vector stores paths to the processed images **/
63         std::vector<std::string> imageNames;
64         parseInputFilesArguments(imageNames);
65         if (imageNames.empty()) throw std::logic_error("No suitable images were found");
66         // -----------------------------------------------------------------------------------------------------
67
68         // --------------------------- 1. Load Plugin for inference engine -------------------------------------
69         slog::info << "Loading plugin" << slog::endl;
70         InferencePlugin plugin = PluginDispatcher({FLAGS_pp}).getPluginByDevice(FLAGS_d);
71
72         /** Printing plugin version **/
73         printPluginVersion(plugin, std::cout);
74
75         /** Loading default extensions **/
76         if (FLAGS_d.find("CPU") != std::string::npos) {
77             /**
78              * cpu_extensions library is compiled from "extension" folder containing
79              * custom MKLDNNPlugin layer implementations. These layers are not supported
80              * by mkldnn, but they can be useful for inferring custom topologies.
81             **/
82             plugin.AddExtension(std::make_shared<Extensions::Cpu::CpuExtensions>());
83         }
84
85         if (!FLAGS_l.empty()) {
86             // CPU(MKLDNN) extensions are loaded as a shared library and passed as a pointer to base extension
87             IExtensionPtr extension_ptr = make_so_pointer<IExtension>(FLAGS_l);
88             plugin.AddExtension(extension_ptr);
89             slog::info << "CPU Extension loaded: " << FLAGS_l << slog::endl;
90         }
91         if (!FLAGS_c.empty()) {
92             // clDNN Extensions are loaded from an .xml description and OpenCL kernel files
93             plugin.SetConfig({{PluginConfigParams::KEY_CONFIG_FILE, FLAGS_c}});
94             slog::info << "GPU Extension loaded: " << FLAGS_c << slog::endl;
95         }
96         // -----------------------------------------------------------------------------------------------------
97
98         // --------------------------- 2. Read IR Generated by ModelOptimizer (.xml and .bin files) ------------
99         slog::info << "Loading network files" << slog::endl;
100
101         CNNNetReader networkReader;
102         /** Read network model **/
103         networkReader.ReadNetwork(FLAGS_m);
104
105         /** Extract model name and load weights **/
106         std::string binFileName = fileNameNoExt(FLAGS_m) + ".bin";
107         networkReader.ReadWeights(binFileName);
108         CNNNetwork network = networkReader.getNetwork();
109         // -----------------------------------------------------------------------------------------------------
110
111         // --------------------------- 3. Configure input & output ---------------------------------------------
112
113         // --------------------------- Prepare input blobs -----------------------------------------------------
114         slog::info << "Preparing input blobs" << slog::endl;
115
116         /** Taking information about all topology inputs **/
117         InputsDataMap inputInfo(network.getInputsInfo());
118
119         if (inputInfo.size() != 1) throw std::logic_error("Sample supports topologies only with 1 input");
120         auto inputInfoItem = *inputInfo.begin();
121
122         /** Iterate over all the input blobs **/
123         std::vector<std::shared_ptr<uint8_t>> imagesData;
124
125         /** Specifying the precision of input data.
126          * This should be called before load of the network to the plugin **/
127         inputInfoItem.second->setPrecision(Precision::FP32);
128
129         /** Collect images data ptrs **/
130         for (auto & i : imageNames) {
131             FormatReader::ReaderPtr reader(i.c_str());
132             if (reader.get() == nullptr) {
133                 slog::warn << "Image " + i + " cannot be read!" << slog::endl;
134                 continue;
135             }
136             /** Store image data **/
137             std::shared_ptr<unsigned char> data(reader->getData(inputInfoItem.second->getTensorDesc().getDims()[3],
138                                                                 inputInfoItem.second->getTensorDesc().getDims()[2]));
139             if (data.get() != nullptr) {
140                 imagesData.push_back(data);
141             }
142         }
143         if (imagesData.empty()) throw std::logic_error("Valid input images were not found!");
144
145         /** Setting batch size using image count **/
146         network.setBatchSize(imagesData.size());
147         slog::info << "Batch size is " << std::to_string(network.getBatchSize()) << slog::endl;
148
149         // ------------------------------ Prepare output blobs -------------------------------------------------
150         slog::info << "Preparing output blobs" << slog::endl;
151
152         OutputsDataMap outputInfo(network.getOutputsInfo());
153         // BlobMap outputBlobs;
154         std::string firstOutputName;
155
156         const float meanValues[] = {static_cast<const float>(FLAGS_mean_val_r),
157                                     static_cast<const float>(FLAGS_mean_val_g),
158                                     static_cast<const float>(FLAGS_mean_val_b)};
159
160         for (auto & item : outputInfo) {
161             if (firstOutputName.empty()) {
162                 firstOutputName = item.first;
163             }
164             DataPtr outputData = item.second;
165             if (!outputData) {
166                 throw std::logic_error("output data pointer is not valid");
167             }
168
169             item.second->setPrecision(Precision::FP32);
170         }
171         // -----------------------------------------------------------------------------------------------------
172
173         // --------------------------- 4. Loading model to the plugin ------------------------------------------
174         slog::info << "Loading model to the plugin" << slog::endl;
175         ExecutableNetwork executable_network = plugin.LoadNetwork(network, {});
176         // -----------------------------------------------------------------------------------------------------
177
178         // --------------------------- 5. Create infer request -------------------------------------------------
179         InferRequest infer_request = executable_network.CreateInferRequest();
180         // -----------------------------------------------------------------------------------------------------
181
182         // --------------------------- 6. Prepare input --------------------------------------------------------
183         /** Iterate over all the input blobs **/
184         for (const auto & item : inputInfo) {
185             Blob::Ptr input = infer_request.GetBlob(item.first);
186             /** Filling input tensor with images. First b channel, then g and r channels **/
187             size_t num_channels = input->getTensorDesc().getDims()[1];
188             size_t image_size = input->getTensorDesc().getDims()[3] * input->getTensorDesc().getDims()[2];
189
190             auto data = input->buffer().as<PrecisionTrait<Precision::FP32>::value_type*>();
191
192             /** Iterate over all input images **/
193             for (size_t image_id = 0; image_id < imagesData.size(); ++image_id) {
194                 /** Iterate over all pixel in image (b,g,r) **/
195                 for (size_t pid = 0; pid < image_size; pid++) {
196                     /** Iterate over all channels **/
197                     for (size_t ch = 0; ch < num_channels; ++ch) {
198                         /**          [images stride + channels stride + pixel id ] all in bytes            **/
199                         data[image_id * image_size * num_channels + ch * image_size + pid ] =
200                             imagesData.at(image_id).get()[pid*num_channels + ch] - meanValues[ch];
201                     }
202                 }
203             }
204         }
205         // -----------------------------------------------------------------------------------------------------
206
207         // --------------------------- 7. Do inference ---------------------------------------------------------
208         slog::info << "Start inference (" << FLAGS_ni << " iterations)" << slog::endl;
209
210         typedef std::chrono::high_resolution_clock Time;
211         typedef std::chrono::duration<double, std::ratio<1, 1000>> ms;
212         typedef std::chrono::duration<float> fsec;
213
214         double total = 0.0;
215         /** Start inference & calc performance **/
216         for (size_t iter = 0; iter < FLAGS_ni; ++iter) {
217             auto t0 = Time::now();
218             infer_request.Infer();
219             auto t1 = Time::now();
220             fsec fs = t1 - t0;
221             ms d = std::chrono::duration_cast<ms>(fs);
222             total += d.count();
223         }
224
225         /** Show performance results **/
226         std::cout << std::endl << "Average running time of one iteration: " << total / static_cast<double>(FLAGS_ni)
227                   << " ms" << std::endl;
228
229         if (FLAGS_pc) {
230             printPerformanceCounts(infer_request, std::cout);
231         }
232         // -----------------------------------------------------------------------------------------------------
233
234         // --------------------------- 8. Process output -------------------------------------------------------
235         const Blob::Ptr output_blob = infer_request.GetBlob(firstOutputName);
236         const auto output_data = output_blob->buffer().as<PrecisionTrait<Precision::FP32>::value_type*>();
237
238         size_t num_images = output_blob->getTensorDesc().getDims()[0];
239         size_t num_channels = output_blob->getTensorDesc().getDims()[1];
240         size_t H = output_blob->getTensorDesc().getDims()[2];
241         size_t W = output_blob->getTensorDesc().getDims()[3];
242         size_t nPixels = W * H;
243
244         slog::info << "Output size [N,C,H,W]: " << num_images << ", " << num_channels << ", " << H << ", " << W << slog::endl;
245
246         {
247             std::vector<float> data_img(nPixels * num_channels);
248
249             for (size_t n = 0; n < num_images; n++) {
250                 for (size_t i = 0; i < nPixels; i++) {
251                     data_img[i * num_channels] = static_cast<float>(output_data[i + n * nPixels * num_channels] +
252                                                                    meanValues[0]);
253                     data_img[i * num_channels + 1] = static_cast<float>(
254                             output_data[(i + nPixels) + n * nPixels * num_channels] + meanValues[1]);
255                     data_img[i * num_channels + 2] = static_cast<float>(
256                             output_data[(i + 2 * nPixels) + n * nPixels * num_channels] + meanValues[2]);
257
258                     float temp = data_img[i * num_channels];
259                     data_img[i * num_channels] = data_img[i * num_channels + 2];
260                     data_img[i * num_channels + 2] = temp;
261
262                     if (data_img[i * num_channels] < 0) data_img[i * num_channels] = 0;
263                     if (data_img[i * num_channels] > 255) data_img[i * num_channels] = 255;
264
265                     if (data_img[i * num_channels + 1] < 0) data_img[i * num_channels + 1] = 0;
266                     if (data_img[i * num_channels + 1] > 255) data_img[i * num_channels + 1] = 255;
267
268                     if (data_img[i * num_channels + 2] < 0) data_img[i * num_channels + 2] = 0;
269                     if (data_img[i * num_channels + 2] > 255) data_img[i * num_channels + 2] = 255;
270                 }
271                 std::string out_img_name = std::string("out" + std::to_string(n + 1) + ".bmp");
272                 std::ofstream outFile;
273                 outFile.open(out_img_name.c_str(), std::ios_base::binary);
274                 if (!outFile.is_open()) {
275                     throw new std::runtime_error("Cannot create " + out_img_name);
276                 }
277                 std::vector<unsigned char> data_img2;
278                 for (float i : data_img) {
279                     data_img2.push_back(static_cast<unsigned char>(i));
280                 }
281                 writeOutputBmp(data_img2.data(), H, W, outFile);
282                 outFile.close();
283                 slog::info << "Image " << out_img_name << " created!" << slog::endl;
284             }
285         }
286         // -----------------------------------------------------------------------------------------------------
287     }
288     catch (const std::exception &error) {
289         slog::err << error.what() << slog::endl;
290         return 1;
291     }
292     catch (...) {
293         slog::err << "Unknown/internal exception happened" << slog::endl;
294         return 1;
295     }
296
297     slog::info << "Execution successful" << slog::endl;
298     return 0;
299 }