Publishing R5 content (#72)
[platform/upstream/dldt.git] / inference-engine / samples / speech_sample / main.cpp
1 // Copyright (C) 2018 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "speech_sample.hpp"
6
7 #include <gflags/gflags.h>
8 #include <functional>
9 #include <iostream>
10 #include <memory>
11 #include <map>
12 #include <fstream>
13 #include <random>
14 #include <string>
15 #include <vector>
16 #include <utility>
17 #include <time.h>
18 #include <thread>
19 #include <chrono>
20 #include <limits>
21 #include <iomanip>
22 #include <inference_engine.hpp>
23 #include <gna/gna_config.hpp>
24
25 #include <samples/common.hpp>
26 #include <samples/slog.hpp>
27 #include <samples/args_helper.hpp>
28
29 #ifndef ALIGN
30 #define ALIGN(memSize, pad)   ((static_cast<int>((memSize) + pad - 1) / pad) * pad)
31 #endif
32 #define MAX_SCORE_DIFFERENCE 0.0001f
33 #define MAX_VAL_2B_FEAT 16384
34
35 using namespace InferenceEngine;
36
37 typedef std::chrono::high_resolution_clock Time;
38 typedef std::chrono::duration<double, std::ratio<1, 1000>> ms;
39 typedef std::chrono::duration<float> fsec;
40 typedef struct {
41     uint32_t numScores;
42     uint32_t numErrors;
43     float threshold;
44     float maxError;
45     float rmsError;
46     float sumError;
47     float sumRmsError;
48     float sumSquaredError;
49     float maxRelError;
50     float sumRelError;
51     float sumSquaredRelError;
52 } score_error_t;
53
54 void GetKaldiArkInfo(const char *fileName,
55                      uint32_t numArrayToFindSize,
56                      uint32_t *ptrNumArrays,
57                      uint32_t *ptrNumMemoryBytes) {
58     uint32_t numArrays = 0;
59     uint32_t numMemoryBytes = 0;
60
61     std::ifstream in_file(fileName, std::ios::binary);
62     if (in_file.good()) {
63         while (!in_file.eof()) {
64             std::string line;
65             uint32_t numRows = 0u, numCols = 0u, num_bytes = 0u;
66             std::getline(in_file, line, '\0');  // read variable length name followed by space and NUL
67             std::getline(in_file, line, '\4');  // read "BFM" followed by space and control-D
68             if (line.compare("BFM ") != 0) {
69                 break;
70             }
71             in_file.read(reinterpret_cast<char *>(&numRows), sizeof(uint32_t));  // read number of rows
72             std::getline(in_file, line, '\4');                                   // read control-D
73             in_file.read(reinterpret_cast<char *>(&numCols), sizeof(uint32_t));  // read number of columns
74             num_bytes = numRows * numCols * sizeof(float);
75             in_file.seekg(num_bytes, in_file.cur);                               // read data
76
77             if (numArrays == numArrayToFindSize) {
78                 numMemoryBytes += num_bytes;
79             }
80             numArrays++;
81         }
82         in_file.close();
83     } else {
84         fprintf(stderr, "Failed to open %s for reading in GetKaldiArkInfo()!\n", fileName);
85         exit(-1);
86     }
87
88     if (ptrNumArrays != NULL) *ptrNumArrays = numArrays;
89     if (ptrNumMemoryBytes != NULL) *ptrNumMemoryBytes = numMemoryBytes;
90 }
91
92 void LoadKaldiArkArray(const char *fileName, uint32_t arrayIndex, std::string &ptrName, std::vector<uint8_t> &memory,
93                        uint32_t *ptrNumRows, uint32_t *ptrNumColumns, uint32_t *ptrNumBytesPerElement) {
94     std::ifstream in_file(fileName, std::ios::binary);
95     if (in_file.good()) {
96         uint32_t i = 0;
97         while (i < arrayIndex) {
98             std::string line;
99             uint32_t numRows = 0u, numCols = 0u;
100             std::getline(in_file, line, '\0');  // read variable length name followed by space and NUL
101             std::getline(in_file, line, '\4');  // read "BFM" followed by space and control-D
102             if (line.compare("BFM ") != 0) {
103                 break;
104             }
105             in_file.read(reinterpret_cast<char *>(&numRows), sizeof(uint32_t));     // read number of rows
106             std::getline(in_file, line, '\4');                                     // read control-D
107             in_file.read(reinterpret_cast<char *>(&numCols), sizeof(uint32_t));     // read number of columns
108             in_file.seekg(numRows * numCols * sizeof(float), in_file.cur);         // read data
109             i++;
110         }
111         if (!in_file.eof()) {
112             std::string line;
113             std::getline(in_file, ptrName, '\0');     // read variable length name followed by space and NUL
114             std::getline(in_file, line, '\4');       // read "BFM" followed by space and control-D
115             if (line.compare("BFM ") != 0) {
116                 fprintf(stderr, "Cannot find array specifier in file %s in LoadKaldiArkArray()!\n", fileName);
117                 exit(-1);
118             }
119             in_file.read(reinterpret_cast<char *>(ptrNumRows), sizeof(uint32_t));        // read number of rows
120             std::getline(in_file, line, '\4');                                            // read control-D
121             in_file.read(reinterpret_cast<char *>(ptrNumColumns), sizeof(uint32_t));    // read number of columns
122             size_t willWrite = *ptrNumRows * *ptrNumColumns * sizeof(float);
123             in_file.read(reinterpret_cast<char *>(&memory.front()),
124                          *ptrNumRows * *ptrNumColumns * sizeof(float));  // read array data
125         }
126         in_file.close();
127     } else {
128         fprintf(stderr, "Failed to open %s for reading in GetKaldiArkInfo()!\n", fileName);
129         exit(-1);
130     }
131
132     *ptrNumBytesPerElement = sizeof(float);
133 }
134
135 void SaveKaldiArkArray(const char *fileName,
136                        bool shouldAppend,
137                        std::string name,
138                        void *ptrMemory,
139                        uint32_t numRows,
140                        uint32_t numColumns) {
141     std::ios_base::openmode mode = std::ios::binary;
142     if (shouldAppend) {
143         mode |= std::ios::app;
144     }
145     std::ofstream out_file(fileName, mode);
146     if (out_file.good()) {
147         out_file.write(name.c_str(), name.length());  // write name
148         out_file.write("\0", 1);
149         out_file.write("BFM ", 4);
150         out_file.write("\4", 1);
151         out_file.write(reinterpret_cast<char *>(&numRows), sizeof(uint32_t));
152         out_file.write("\4", 1);
153         out_file.write(reinterpret_cast<char *>(&numColumns), sizeof(uint32_t));
154         out_file.write(reinterpret_cast<char *>(ptrMemory), numRows * numColumns * sizeof(float));
155         out_file.close();
156     } else {
157         throw std::runtime_error(std::string("Failed to open %s for writing in SaveKaldiArkArray()!\n") + fileName);
158     }
159 }
160
161 float ScaleFactorForQuantization(void *ptrFloatMemory, float targetMax, uint32_t numElements) {
162     float *ptrFloatFeat = reinterpret_cast<float *>(ptrFloatMemory);
163     float max = 0.0;
164     float scaleFactor;
165
166     for (uint32_t i = 0; i < numElements; i++) {
167         if (fabs(ptrFloatFeat[i]) > max) {
168             max = fabs(ptrFloatFeat[i]);
169         }
170     }
171
172     if (max == 0) {
173         scaleFactor = 1.0;
174     } else {
175         scaleFactor = targetMax / max;
176     }
177
178     return (scaleFactor);
179 }
180
181 void ClearScoreError(score_error_t *error) {
182     error->numScores = 0;
183     error->numErrors = 0;
184     error->maxError = 0.0;
185     error->rmsError = 0.0;
186     error->sumError = 0.0;
187     error->sumRmsError = 0.0;
188     error->sumSquaredError = 0.0;
189     error->maxRelError = 0.0;
190     error->sumRelError = 0.0;
191     error->sumSquaredRelError = 0.0;
192 }
193
194 void UpdateScoreError(score_error_t *error, score_error_t *totalError) {
195     totalError->numErrors += error->numErrors;
196     totalError->numScores += error->numScores;
197     totalError->sumRmsError += error->rmsError;
198     totalError->sumError += error->sumError;
199     totalError->sumSquaredError += error->sumSquaredError;
200     if (error->maxError > totalError->maxError) {
201         totalError->maxError = error->maxError;
202     }
203     totalError->sumRelError += error->sumRelError;
204     totalError->sumSquaredRelError += error->sumSquaredRelError;
205     if (error->maxRelError > totalError->maxRelError) {
206         totalError->maxRelError = error->maxRelError;
207     }
208 }
209
210 uint32_t CompareScores(float *ptrScoreArray,
211                        void *ptrRefScoreArray,
212                        score_error_t *scoreError,
213                        uint32_t numRows,
214                        uint32_t numColumns) {
215     uint32_t numErrors = 0;
216
217     ClearScoreError(scoreError);
218
219     float *A = ptrScoreArray;
220     float *B = reinterpret_cast<float *>(ptrRefScoreArray);
221     for (uint32_t i = 0; i < numRows; i++) {
222         for (uint32_t j = 0; j < numColumns; j++) {
223             float score = A[i * numColumns + j];
224             float refscore = B[i * numColumns + j];
225             float error = fabs(refscore - score);
226             float rel_error = error / (static_cast<float>(fabs(refscore)) + 1e-20f);
227             float squared_error = error * error;
228             float squared_rel_error = rel_error * rel_error;
229             scoreError->numScores++;
230             scoreError->sumError += error;
231             scoreError->sumSquaredError += squared_error;
232             if (error > scoreError->maxError) {
233                 scoreError->maxError = error;
234             }
235             scoreError->sumRelError += rel_error;
236             scoreError->sumSquaredRelError += squared_rel_error;
237             if (rel_error > scoreError->maxRelError) {
238                 scoreError->maxRelError = rel_error;
239             }
240             if (error > scoreError->threshold) {
241                 numErrors++;
242             }
243         }
244     }
245     scoreError->rmsError = sqrt(scoreError->sumSquaredError / (numRows * numColumns));
246     scoreError->sumRmsError += scoreError->rmsError;
247     scoreError->numErrors = numErrors;
248
249     return (numErrors);
250 }
251
252 float StdDevError(score_error_t error) {
253     return (sqrt(error.sumSquaredError / error.numScores
254                  - (error.sumError / error.numScores) * (error.sumError / error.numScores)));
255 }
256
257 float StdDevRelError(score_error_t error) {
258     return (sqrt(error.sumSquaredRelError / error.numScores
259                  - (error.sumRelError / error.numScores) * (error.sumRelError / error.numScores)));
260 }
261
262 #if !defined(__arm__) && !defined(_M_ARM)
263 #if defined(_WIN32) || defined(WIN32)
264 #include <intrin.h>
265 #include <windows.h>
266 #else
267
268 #include <cpuid.h>
269
270 #endif
271
272 inline void native_cpuid(unsigned int *eax, unsigned int *ebx,
273                          unsigned int *ecx, unsigned int *edx) {
274     size_t level = *eax;
275 #if defined(_WIN32) || defined(WIN32)
276     int regs[4] = {static_cast<int>(*eax), static_cast<int>(*ebx), static_cast<int>(*ecx), static_cast<int>(*edx)};
277     __cpuid(regs, level);
278     *eax = static_cast<uint32_t>(regs[0]);
279     *ebx = static_cast<uint32_t>(regs[1]);
280     *ecx = static_cast<uint32_t>(regs[2]);
281     *edx = static_cast<uint32_t>(regs[3]);
282 #else
283     __get_cpuid(level, eax, ebx, ecx, edx);
284 #endif
285 }
286
287 // return GNA module frequency in MHz
288 float getGnaFrequencyMHz() {
289     uint32_t level = 0;
290     uint32_t eax = 1;
291     uint32_t ebx = 0;
292     uint32_t ecx = 0;
293     uint32_t edx = 0;
294     uint32_t family = 0;
295     uint32_t model = 0;
296     const uint8_t sixth_family = 6;
297     const uint8_t cannon_lake_model = 102;
298     const uint8_t gemini_lake_model = 122;
299
300     native_cpuid(&eax, &ebx, &ecx, &edx);
301     family = (eax >> 8) & 0xF;
302
303     // model is the concatenation of two fields
304     // | extended model | model |
305     // copy extended model data
306     model = (eax >> 16) & 0xF;
307     // shift
308     model <<= 4;
309     // copy model data
310     model += (eax >> 4) & 0xF;
311
312     if (family == sixth_family && model == cannon_lake_model) {
313         return 400;
314     } else if (family == sixth_family &&
315                model == gemini_lake_model) {
316         return 200;
317     } else {
318         // counters not supported and we retrns just default value
319         return 1;
320     }
321 }
322
323 #endif  // !defined(__arm__) && !defined(_M_ARM)
324
325 void printReferenceCompareResults(score_error_t const &totalError,
326                                   size_t framesNum,
327                                   std::ostream &stream) {
328     stream << "         max error: " <<
329            totalError.maxError << std::endl;
330     stream << "         avg error: " <<
331            totalError.sumError / totalError.numScores << std::endl;
332     stream << "     avg rms error: " <<
333            totalError.sumRmsError / framesNum << std::endl;
334     stream << "       stdev error: " <<
335            StdDevError(totalError) << std::endl << std::endl;
336     stream << std::endl;
337 }
338
339 void printPerformanceCounters(std::map<std::string,
340         InferenceEngine::InferenceEngineProfileInfo> const &utterancePerfMap,
341                               size_t callsNum,
342                               std::ostream &stream) {
343 #if !defined(__arm__) && !defined(_M_ARM)
344     stream << std::endl << "Performance counts:" << std::endl;
345     stream << std::setw(10) << std::right << "" << "Counter descriptions";
346     stream << std::setw(22) << "Utt scoring time";
347     stream << std::setw(18) << "Avg infer time";
348     stream << std::endl;
349
350     stream << std::setw(46) << "(ms)";
351     stream << std::setw(24) << "(us per call)";
352     stream << std::endl;
353
354     for (const auto &it : utterancePerfMap) {
355         std::string const &counter_name = it.first;
356         float current_units = it.second.realTime_uSec;
357         float call_units = current_units / callsNum;
358         float freq = 1.0;
359         // if GNA HW counters
360         // get frequency of GNA module
361         freq = getGnaFrequencyMHz();
362         current_units /= freq * 1000;
363         call_units /= freq;
364         stream << std::setw(30) << std::left << counter_name.substr(4, counter_name.size() - 1);
365         stream << std::setw(16) << std::right << current_units;
366         stream << std::setw(21) << std::right << call_units;
367         stream << std::endl;
368     }
369     stream << std::endl;
370 #endif
371 }
372
373 void getPerformanceCounters(InferenceEngine::InferRequest &request,
374                             std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &perfCounters) {
375     auto retPerfCounters = request.GetPerformanceCounts();
376
377     for (const auto &pair : retPerfCounters) {
378         perfCounters[pair.first] = pair.second;
379     }
380 }
381
382 void sumPerformanceCounters(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> const &perfCounters,
383                             std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &totalPerfCounters) {
384     for (const auto &pair : perfCounters) {
385         totalPerfCounters[pair.first].realTime_uSec += pair.second.realTime_uSec;
386     }
387 }
388
389 bool ParseAndCheckCommandLine(int argc, char *argv[]) {
390     // ---------------------------Parsing and validation of input args--------------------------------------
391     slog::info << "Parsing input parameters" << slog::endl;
392
393     gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
394     if (FLAGS_h) {
395         showUsage();
396         return false;
397     }
398     bool isDumpMode = !FLAGS_wg.empty() || !FLAGS_we.empty();
399
400     // input not required only in dump mode and if external scale factor provided
401     if (FLAGS_i.empty() && (!isDumpMode || FLAGS_q.compare("user") != 0)) {
402         if (isDumpMode) {
403             throw std::logic_error("In model dump mode either static quantization is used (-i) or user scale"
404                                    " factor need to be provided. See -q user option");
405         }
406         throw std::logic_error("Input file not set. Please use -i.");
407     }
408
409     if (FLAGS_m.empty() && FLAGS_rg.empty()) {
410         throw std::logic_error("Either IR file (-m) or GNAModel file (-rg) need to be set.");
411     }
412
413     if ((!FLAGS_m.empty() && !FLAGS_rg.empty())) {
414         throw std::logic_error("Only one of -m and -rg is allowed.");
415     }
416
417     if ((FLAGS_d.compare("GPU") != 0) && (FLAGS_d.compare("CPU") != 0) && (FLAGS_d.compare("GNA_AUTO") != 0) &&
418         (FLAGS_d.compare("GNA_HW") != 0)
419         && (FLAGS_d.compare("GNA_SW") != 0) && (FLAGS_d.compare("GNA_SW_EXACT") != 0)) {
420         throw std::logic_error("Specified device is not supported.");
421     }
422
423     float scaleFactorInput = static_cast<float>(FLAGS_sf);
424     if (scaleFactorInput <= 0.0f) {
425         throw std::logic_error("Scale factor out of range (must be non-negative).");
426     }
427
428     uint32_t batchSize = (uint32_t) FLAGS_bs;
429     if ((batchSize < 1) || (batchSize > 8)) {
430         throw std::logic_error("Batch size out of range (1..8).");
431     }
432
433     /** default is a static quantisation **/
434     if ((FLAGS_q.compare("static") != 0) && (FLAGS_q.compare("dynamic") != 0) && (FLAGS_q.compare("user") != 0)) {
435         throw std::logic_error("Quantization mode not supported (static, dynamic, user).");
436     }
437
438     if (FLAGS_q.compare("dynamic") == 0) {
439         throw std::logic_error("Dynamic quantization not yet supported.");
440     }
441
442     if (FLAGS_qb != 16 && FLAGS_qb != 8) {
443         throw std::logic_error("Only 8 or 16 bits supported.");
444     }
445
446     if (FLAGS_nthreads <= 0) {
447         throw std::logic_error("Not valid value for 'nthreads' argument. It should be > 0 ");
448     }
449
450     return true;
451 }
452
453 /**
454  * @brief The entry point for inference engine automatic speech recognition sample
455  * @file speech_sample/main.cpp
456  * @example speech_sample/main.cpp
457  */
458 int main(int argc, char *argv[]) {
459     try {
460         slog::info << "InferenceEngine: " << GetInferenceEngineVersion() << slog::endl;
461
462         // ------------------------------ Parsing and validation of input args ---------------------------------
463         if (!ParseAndCheckCommandLine(argc, argv)) {
464             return 0;
465         }
466
467         if (FLAGS_l.empty()) {
468             slog::info << "No extensions provided" << slog::endl;
469         }
470
471         bool useGna = (FLAGS_d.find("GNA") != std::string::npos);
472         auto deviceStr = FLAGS_d.substr(0, (FLAGS_d.find("_")));
473         float scaleFactorInput = static_cast<float>(FLAGS_sf);
474         uint32_t batchSize = (uint32_t) FLAGS_bs;
475         /** Extract input ark file name **/
476         std::string inputArkName = fileNameNoExt(FLAGS_i) + ".ark";
477
478         uint32_t numUtterances(0), numBytesThisUtterance(0);
479         if (!FLAGS_i.empty()) {
480             GetKaldiArkInfo(inputArkName.c_str(), 0, &numUtterances, &numBytesThisUtterance);
481         }
482         // -----------------------------------------------------------------------------------------------------
483
484         // --------------------------- 1. Load Plugin for inference engine -------------------------------------
485         slog::info << "Loading plugin" << slog::endl;
486         /** Loading plugin for device **/
487         InferencePlugin plugin = PluginDispatcher({FLAGS_pp, "../../../lib/intel64", ""}).getPluginByDevice(deviceStr);
488
489         /** Printing plugin version **/
490         std::cout << plugin.GetVersion() << std::endl << std::endl;
491         // -----------------------------------------------------------------------------------------------------
492
493         // --------------------------- 2. Read IR Generated by ModelOptimizer (.xml and .bin files) ------------
494         slog::info << "Loading network files" << slog::endl;
495
496         CNNNetReader netBuilder;
497         if (!FLAGS_m.empty()) {
498             /** Read network model **/
499             netBuilder.ReadNetwork(FLAGS_m);
500
501             /** Extract model name and load weights **/
502             std::string binFileName = fileNameNoExt(FLAGS_m) + ".bin";
503             netBuilder.ReadWeights(binFileName);
504
505             // -------------------------------------------------------------------------------------------------
506
507             // --------------------------- 3. Set batch size ---------------------------------------------------
508             /** Set batch size.  Unlike in imaging, batching in time (rather than space) is done for speech recognition. **/
509             netBuilder.getNetwork().setBatchSize(batchSize);
510             slog::info << "Batch size is " << std::to_string(netBuilder.getNetwork().getBatchSize())
511                        << slog::endl;
512         }
513
514         /** Setting plugin parameter for per layer metrics **/
515         std::map<std::string, std::string> gnaPluginConfig;
516         std::map<std::string, std::string> genericPluginConfig;
517         if (FLAGS_d.compare("CPU") != 0) {
518             gnaPluginConfig[GNAConfigParams::KEY_GNA_DEVICE_MODE] = FLAGS_d;
519         }
520         if (FLAGS_pc) {
521             genericPluginConfig[PluginConfigParams::KEY_PERF_COUNT] = PluginConfigParams::YES;
522         }
523
524         if (FLAGS_q.compare("user") == 0) {
525             std::cout << "[ INFO ] Using scale factor of " << FLAGS_sf << std::endl;
526             gnaPluginConfig[GNA_CONFIG_KEY(SCALE_FACTOR)] = std::to_string(FLAGS_sf);
527         } else {  // "static" quantization with calculated scale factor
528             std::string name;
529             std::vector<uint8_t> ptrFeatures;
530             uint32_t numArrays(0), numBytes(0), numFrames(0), numFrameElements(0), numBytesPerElement(0);
531             GetKaldiArkInfo(inputArkName.c_str(), 0, &numArrays, &numBytes);
532             ptrFeatures.resize(numBytes);
533             LoadKaldiArkArray(inputArkName.c_str(),
534                               0,
535                               name,
536                               ptrFeatures,
537                               &numFrames,
538                               &numFrameElements,
539                               &numBytesPerElement);
540             scaleFactorInput =
541                     ScaleFactorForQuantization(ptrFeatures.data(), MAX_VAL_2B_FEAT, numFrames * numFrameElements);
542             slog::info << "Using scale factor of " << scaleFactorInput << " calculated from first utterance."
543                        << slog::endl;
544             gnaPluginConfig[GNA_CONFIG_KEY(SCALE_FACTOR)] = std::to_string(scaleFactorInput);
545         }
546
547         if (FLAGS_qb == 8) {
548             gnaPluginConfig[GNAConfigParams::KEY_GNA_PRECISION] = "I8";
549         } else {
550             gnaPluginConfig[GNAConfigParams::KEY_GNA_PRECISION] = "I16";
551         }
552
553         gnaPluginConfig[GNAConfigParams::KEY_GNA_LIB_N_THREADS] = std::to_string(FLAGS_nthreads);
554         gnaPluginConfig[GNA_CONFIG_KEY(COMPACT_MODE)] = CONFIG_VALUE(NO);
555         // -----------------------------------------------------------------------------------------------------
556
557         // --------------------------- 4. Write model to file --------------------------------------------------
558         // Embedded GNA model dumping (for Intel(R) Speech Enabling Developer Kit)
559         if (!FLAGS_we.empty()) {
560             gnaPluginConfig[GNAConfigParams::KEY_GNA_FIRMWARE_MODEL_IMAGE] = FLAGS_we;
561         }
562         // -----------------------------------------------------------------------------------------------------
563
564         // --------------------------- 5. Loading model to the plugin ------------------------------------------
565
566         if (useGna) {
567             genericPluginConfig.insert(std::begin(gnaPluginConfig), std::end(gnaPluginConfig));
568         }
569         auto t0 = Time::now();
570         ExecutableNetwork executableNet;
571         if (!FLAGS_m.empty()) {
572             slog::info << "Loading model to the plugin" << slog::endl;
573             executableNet = plugin.LoadNetwork(netBuilder.getNetwork(), genericPluginConfig);
574         } else {
575             slog::info << "Importing model to the plugin" << slog::endl;
576             executableNet = plugin.ImportNetwork(FLAGS_rg.c_str(), genericPluginConfig);
577         }
578
579
580         ms loadTime = std::chrono::duration_cast<ms>(Time::now() - t0);
581         slog::info << "Model loading time " << loadTime.count() << " ms" << slog::endl;
582
583         // --------------------------- 6. Exporting gna model using InferenceEngine AOT API---------------------
584         if (!FLAGS_wg.empty()) {
585             slog::info << "Writing GNA Model to file " << FLAGS_wg << slog::endl;
586             t0 = Time::now();
587             executableNet.Export(FLAGS_wg);
588             ms exportTime = std::chrono::duration_cast<ms>(Time::now() - t0);
589             slog::info << "Exporting time " << exportTime.count() << " ms" << slog::endl;
590             return 0;
591         }
592
593         if (!FLAGS_we.empty()) {
594             slog::info << "Exported GNA embedded model to file " << FLAGS_we << slog::endl;
595             return 0;
596         }
597
598         std::vector<std::pair<InferRequest, size_t>> inferRequests(FLAGS_nthreads);
599         for (auto& inferRequest : inferRequests) {
600             inferRequest = {executableNet.CreateInferRequest(), -1};
601         }
602         // -----------------------------------------------------------------------------------------------------
603
604         // --------------------------- 7. Prepare input blobs --------------------------------------------------
605         /** Taking information about all topology inputs **/
606         ConstInputsDataMap cInputInfo = executableNet.GetInputsInfo();
607         InputsDataMap inputInfo;
608         if (!FLAGS_m.empty()) {
609             inputInfo = netBuilder.getNetwork().getInputsInfo();
610         }
611
612         /** Stores all input blobs data **/
613         if (cInputInfo.size() != 1) {
614             throw std::logic_error("Sample supports only topologies with  1 input");
615         }
616
617         Blob::Ptr ptrInputBlob = inferRequests[0].first.GetBlob(cInputInfo.begin()->first);
618
619         /** configure input precision if model loaded from IR **/
620         for (auto &item : inputInfo) {
621             Precision inputPrecision = Precision::FP32;  // specify Precision::I16 to provide quantized inputs
622             item.second->setPrecision(inputPrecision);
623             item.second->getInputData()->layout = NC;  // row major layout
624         }
625
626         // -----------------------------------------------------------------------------------------------------
627
628         // --------------------------- 8. Prepare output blobs -------------------------------------------------
629         ConstOutputsDataMap cOutputInfo(executableNet.GetOutputsInfo());
630         OutputsDataMap outputInfo;
631         if (!FLAGS_m.empty()) {
632             outputInfo = netBuilder.getNetwork().getOutputsInfo();
633         }
634
635         Blob::Ptr ptrOutputBlob = inferRequests[0].first.GetBlob(cOutputInfo.begin()->first);
636
637         for (auto &item : outputInfo) {
638             DataPtr outData = item.second;
639             if (!outData) {
640                 throw std::logic_error("output data pointer is not valid");
641             }
642
643             Precision outputPrecision = Precision::FP32;  // specify Precision::I32 to retrieve quantized outputs
644             outData->setPrecision(outputPrecision);
645             outData->layout = NC;  // row major layout
646         }
647         // -----------------------------------------------------------------------------------------------------
648
649         // --------------------------- 9. Do inference ---------------------------------------------------------
650         std::vector<uint8_t> ptrUtterance;
651         std::vector<uint8_t> ptrScores;
652         std::vector<uint8_t> ptrReferenceScores;
653         score_error_t frameError, totalError;
654
655         for (uint32_t utteranceIndex = 0; utteranceIndex < numUtterances; ++utteranceIndex) {
656             std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> utterancePerfMap;
657             std::string uttName;
658             uint32_t numFrames(0), numFrameElementsInput(0), numBytesPerElementInput(0), n(0);
659             uint32_t numFramesReference(0), numFrameElementsReference(0), numBytesPerElementReference(0),
660                     numBytesReferenceScoreThisUtterance(0);
661             const uint32_t numScoresPerFrame = ptrOutputBlob->size() / batchSize;
662             GetKaldiArkInfo(inputArkName.c_str(), utteranceIndex, &n, &numBytesThisUtterance);
663             ptrUtterance.resize(numBytesThisUtterance);
664             LoadKaldiArkArray(inputArkName.c_str(),
665                               utteranceIndex,
666                               uttName,
667                               ptrUtterance,
668                               &numFrames,
669                               &numFrameElementsInput,
670                               &numBytesPerElementInput);
671
672             uint32_t numFrameElementsInputPadded = numFrameElementsInput;
673
674             if (ptrInputBlob->size() != numFrameElementsInputPadded * batchSize) {
675                 throw std::logic_error("network input size(" + std::to_string(ptrInputBlob->size()) +
676                                        ") mismatch to ark file size (" +
677                                        std::to_string(numFrameElementsInputPadded * batchSize) + ")");
678             }
679             ptrScores.resize(numFrames * numScoresPerFrame * sizeof(float));
680             if (!FLAGS_r.empty()) {
681                 std::string refUtteranceName;
682                 GetKaldiArkInfo(FLAGS_r.c_str(), utteranceIndex, &n, &numBytesReferenceScoreThisUtterance);
683                 ptrReferenceScores.resize(numBytesReferenceScoreThisUtterance);
684                 LoadKaldiArkArray(FLAGS_r.c_str(),
685                                   utteranceIndex,
686                                   refUtteranceName,
687                                   ptrReferenceScores,
688                                   &numFramesReference,
689                                   &numFrameElementsReference,
690                                   &numBytesPerElementReference);
691             }
692
693             double totalTime = 0.0;
694
695             std::cout << "Utterance " << utteranceIndex << ": " << std::endl;
696
697             ClearScoreError(&totalError);
698             totalError.threshold = frameError.threshold = MAX_SCORE_DIFFERENCE;
699             auto inputFrame = &ptrUtterance.front();
700             auto outputFrame = &ptrScores.front();
701
702             size_t frameIndex{0};
703             uint32_t numFramesThisBatch{batchSize};
704
705             auto t0 = Time::now();
706             auto t1 = t0;
707
708             // Doing inference
709             while (frameIndex <= numFrames) {
710                 if (frameIndex == numFrames) {
711                     bool hasRequests = false;
712                     for (auto &inferRequest : inferRequests) {
713                         if (inferRequest.second != -1) {
714                             hasRequests = true;
715                         }
716                     }
717                     if (!hasRequests) {
718                         break;
719                     }
720                 }
721
722                 bool inferRequestFetched = false;
723                 for (auto &inferRequest : inferRequests) {
724                     if (frameIndex == numFrames) {
725                         numFramesThisBatch = 1;
726                     } else {
727                         numFramesThisBatch = (numFrames - frameIndex < batchSize) ? (numFrames - frameIndex) : batchSize;
728                     }
729
730                     if (inferRequest.second != -1) {
731                         StatusCode code = inferRequest.first.Wait(
732                                 InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
733
734                         if (code != StatusCode::OK) {
735                             continue;
736                         }
737
738                         if (!FLAGS_o.empty()) {
739                             Blob::Ptr outputBlob = inferRequest.first.GetBlob(cOutputInfo.begin()->first);
740                             std::memcpy(outputFrame,
741                                         outputBlob->buffer(),
742                                         outputBlob->byteSize());
743                             outputFrame += numScoresPerFrame * sizeof(float);
744                         }
745
746                         if (!FLAGS_r.empty()) {
747                             Blob::Ptr outputBlob = inferRequest.first.GetBlob(cOutputInfo.begin()->first);
748                             CompareScores(outputBlob->buffer().as<float *>(),
749                                           &ptrReferenceScores[inferRequest.second *
750                                                               numFrameElementsReference *
751                                                               numBytesPerElementReference],
752                                           &frameError,
753                                           numFramesThisBatch,
754                                           numFrameElementsReference);
755                             UpdateScoreError(&frameError, &totalError);
756                         }
757                     }
758
759                     inferRequest.second = -1;
760
761                     if (frameIndex == numFrames) {
762                         continue;
763                     }
764
765                     Blob::Ptr inputBlob = inferRequest.first.GetBlob(cInputInfo.begin()->first);
766                     std::memcpy(inputBlob->buffer(),
767                                 inputFrame,
768                                 inputBlob->byteSize());
769
770                     inferRequest.first.StartAsync();
771
772                     inferRequest.second = frameIndex;
773                     frameIndex += numFramesThisBatch;
774                     inputFrame += sizeof(float) * numFrameElementsInput * numFramesThisBatch;
775                     inferRequestFetched |= true;
776                 }
777
778                 if (!inferRequestFetched) {
779                     std::this_thread::sleep_for(std::chrono::milliseconds(1));
780                     continue;
781                 }
782
783                 if (FLAGS_pc) {
784                     std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> callPerfMap;
785                     // retrive new counters
786                     for (auto inferRequest : inferRequests) {
787                         getPerformanceCounters(inferRequest.first, callPerfMap);
788                         // summarize retrived counters with all previous
789                         sumPerformanceCounters(callPerfMap, utterancePerfMap);
790                     }
791                 }
792             }
793             t1 = Time::now();
794
795             fsec fs = t1 - t0;
796             ms d = std::chrono::duration_cast<ms>(fs);
797             totalTime += d.count();
798
799             // resetting state between utterances
800             for (auto &&state : executableNet.QueryState()) {
801                 state.Reset();
802             }
803
804             if (!FLAGS_o.empty()) {
805                 bool shouldAppend = (utteranceIndex == 0) ? false : true;
806                 SaveKaldiArkArray(FLAGS_o.c_str(), shouldAppend, uttName, &ptrScores.front(),
807                                   numFrames, numScoresPerFrame);
808             }
809
810             /** Show performance results **/
811             std::cout << "Total time in Infer (HW and SW):\t" << totalTime << " ms"
812                       << std::endl;
813             std::cout << "Frames in utterance:\t\t\t" << numFrames << " frames"
814                       << std::endl;
815             std::cout << "Average Infer time per frame:\t\t" << totalTime / static_cast<double>(numFrames) << " ms"
816                       << std::endl;
817             if (FLAGS_pc) {
818                 // print
819                 printPerformanceCounters(utterancePerfMap, frameIndex, std::cout);
820             }
821             if (!FLAGS_r.empty()) {
822                 printReferenceCompareResults(totalError, numFrames, std::cout);
823             }
824             std::cout << "End of Utterance " << utteranceIndex << std::endl << std::endl;
825         }
826         // -----------------------------------------------------------------------------------------------------
827     }
828     catch (const std::exception &error) {
829         slog::err << error.what() << slog::endl;
830         return 1;
831     }
832     catch (...) {
833         slog::err << "Unknown/internal exception happened" << slog::endl;
834         return 1;
835     }
836
837     slog::info << "Execution successful" << slog::endl;
838     return 0;
839 }