efc38ca398c2032e27f440c1646ffc463c0b89a4
[platform/upstream/dldt.git] / inference-engine / samples / speech_sample / main.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
#include "speech_sample.hpp"

#include <gflags/gflags.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include <time.h>
#include <inference_engine.hpp>
#include <gna/gna_config.hpp>

#include <samples/common.hpp>
#include <samples/slog.hpp>
#include <samples/args_helper.hpp>
#include <ext_list.hpp>
29
#ifndef ALIGN
// Rounds memSize up to the next multiple of pad; the result is an int.
// Every use of each argument is parenthesized so that expression arguments such as
// ALIGN(n, a + b) expand to the intended arithmetic.
#define ALIGN(memSize, pad)   ((static_cast<int>((memSize) + (pad) - 1) / (pad)) * (pad))
#endif
// Absolute score difference constant; presumably the default comparison threshold used
// when checking scores against a reference (it is used outside this view) - TODO confirm.
#define MAX_SCORE_DIFFERENCE 0.0001f
// Peak magnitude that input features are mapped to when computing the quantization scale
// factor (passed as targetMax to ScaleFactorForQuantization).
#define MAX_VAL_2B_FEAT 16384
35
36 using namespace InferenceEngine;
37
// Clock used for wall-time measurements.
using Time = std::chrono::high_resolution_clock;
// Milliseconds with a fractional (double) count.
using ms = std::chrono::duration<double, std::milli>;
// Seconds with a float count.
using fsec = std::chrono::duration<float>;
/// Error statistics gathered while comparing inferred scores against reference scores.
struct score_error_t {
    uint32_t numScores;          ///< number of scores compared
    uint32_t numErrors;          ///< number of scores whose absolute error exceeded threshold
    float threshold;             ///< absolute-error limit used when counting numErrors
    float maxError;              ///< largest absolute error seen
    float rmsError;              ///< root-mean-square of the absolute errors of one comparison
    float sumError;              ///< sum of absolute errors
    float sumRmsError;           ///< sum of rmsError values over several comparisons
    float sumSquaredError;       ///< sum of squared absolute errors
    float maxRelError;           ///< largest relative error seen
    float sumRelError;           ///< sum of relative errors
    float sumSquaredRelError;    ///< sum of squared relative errors
};
54
55 struct InferRequestStruct {
56     InferRequest inferRequest;
57     int frameIndex;
58     uint32_t numFramesThisBatch;
59 };
60
/**
 * @brief Scans a Kaldi binary ark file of float matrices and reports how many it contains.
 *
 * Each record is read as: name, NUL, "BFM ", control-D, uint32 rows, control-D,
 * uint32 columns, followed by rows * columns floats. Scanning stops at the first record
 * whose header is missing or truncated; such a partial record is not counted.
 * Exits the process with -1 when the file cannot be opened.
 *
 * @param fileName           path of the ark file
 * @param numArrayToFindSize zero-based index of the array whose data size is reported
 * @param ptrNumArrays       if non-NULL, receives the number of complete array headers found
 * @param ptrNumMemoryBytes  if non-NULL, receives the data size in bytes of array
 *                           numArrayToFindSize (0 when the file has fewer arrays; the value is
 *                           limited to 32 bits by this interface)
 */
void GetKaldiArkInfo(const char *fileName,
                     uint32_t numArrayToFindSize,
                     uint32_t *ptrNumArrays,
                     uint32_t *ptrNumMemoryBytes) {
    uint32_t numArrays = 0;
    uint32_t numMemoryBytes = 0;

    std::ifstream in_file(fileName, std::ios::binary);
    if (!in_file.good()) {
        fprintf(stderr, "Failed to open %s for reading in GetKaldiArkInfo()!\n", fileName);
        exit(-1);
    }

    while (true) {
        std::string line;
        uint32_t numRows = 0u, numCols = 0u;
        std::getline(in_file, line, '\0');  // read variable length name followed by space and NUL
        std::getline(in_file, line, '\4');  // read "BFM" followed by space and control-D
        if (!in_file || line.compare("BFM ") != 0) {
            break;
        }
        in_file.read(reinterpret_cast<char *>(&numRows), sizeof(uint32_t));  // read number of rows
        std::getline(in_file, line, '\4');                                   // read control-D
        in_file.read(reinterpret_cast<char *>(&numCols), sizeof(uint32_t));  // read number of columns
        if (!in_file) {
            break;  // header truncated: do not count a record we could not fully read
        }
        // 64-bit arithmetic: numRows * numCols in uint32_t would wrap for large matrices
        // and the seek would land in the middle of the data.
        const uint64_t numBytes = static_cast<uint64_t>(numRows) * numCols * sizeof(float);
        in_file.seekg(static_cast<std::streamoff>(numBytes), std::ios::cur);  // skip data

        if (numArrays == numArrayToFindSize) {
            numMemoryBytes += static_cast<uint32_t>(numBytes);
        }
        numArrays++;
    }

    if (ptrNumArrays != NULL) *ptrNumArrays = numArrays;
    if (ptrNumMemoryBytes != NULL) *ptrNumMemoryBytes = numMemoryBytes;
}
98
/**
 * @brief Loads one float matrix from a Kaldi binary ark file.
 *
 * Skips the first arrayIndex records, then reads the next record's name, dimensions and
 * rows * columns floats into memory. If memory is smaller than the matrix data it is
 * grown to fit; it is never shrunk. If the stream reaches end-of-file before the requested
 * record, name/dimensions/memory are left untouched.
 * Exits the process with -1 when the file cannot be opened or the record header is malformed.
 *
 * @param fileName              path of the ark file
 * @param arrayIndex            zero-based index of the record to load
 * @param ptrName               receives the record name
 * @param memory                receives the raw float data
 * @param ptrNumRows            receives the number of rows
 * @param ptrNumColumns         receives the number of columns
 * @param ptrNumBytesPerElement always set to sizeof(float)
 */
void LoadKaldiArkArray(const char *fileName, uint32_t arrayIndex, std::string &ptrName, std::vector<uint8_t> &memory,
                       uint32_t *ptrNumRows, uint32_t *ptrNumColumns, uint32_t *ptrNumBytesPerElement) {
    std::ifstream in_file(fileName, std::ios::binary);
    if (!in_file.good()) {
        fprintf(stderr, "Failed to open %s for reading in LoadKaldiArkArray()!\n", fileName);
        exit(-1);
    }

    // Skip the records that precede the requested one.
    for (uint32_t i = 0; i < arrayIndex; i++) {
        std::string line;
        uint32_t numRows = 0u, numCols = 0u;
        std::getline(in_file, line, '\0');  // read variable length name followed by space and NUL
        std::getline(in_file, line, '\4');  // read "BFM" followed by space and control-D
        if (line.compare("BFM ") != 0) {
            break;
        }
        in_file.read(reinterpret_cast<char *>(&numRows), sizeof(uint32_t));  // read number of rows
        std::getline(in_file, line, '\4');                                   // read control-D
        in_file.read(reinterpret_cast<char *>(&numCols), sizeof(uint32_t));  // read number of columns
        // 64-bit so the skip offset cannot wrap for large matrices
        const uint64_t numBytes = static_cast<uint64_t>(numRows) * numCols * sizeof(float);
        in_file.seekg(static_cast<std::streamoff>(numBytes), std::ios::cur);  // skip data
    }

    if (!in_file.eof()) {
        std::string line;
        std::getline(in_file, ptrName, '\0');  // read variable length name followed by space and NUL
        std::getline(in_file, line, '\4');     // read "BFM" followed by space and control-D
        if (line.compare("BFM ") != 0) {
            fprintf(stderr, "Cannot find array specifier in file %s in LoadKaldiArkArray()!\n", fileName);
            exit(-1);
        }
        in_file.read(reinterpret_cast<char *>(ptrNumRows), sizeof(uint32_t));     // read number of rows
        std::getline(in_file, line, '\4');                                        // read control-D
        in_file.read(reinterpret_cast<char *>(ptrNumColumns), sizeof(uint32_t));  // read number of columns
        const size_t numBytes = static_cast<size_t>(*ptrNumRows) * *ptrNumColumns * sizeof(float);
        if (memory.size() < numBytes) {
            memory.resize(numBytes);  // never write past the end of the caller's buffer
        }
        if (numBytes > 0) {
            in_file.read(reinterpret_cast<char *>(memory.data()), static_cast<std::streamsize>(numBytes));
        }
    }

    *ptrNumBytesPerElement = sizeof(float);
}
140
/**
 * @brief Writes one float matrix to a Kaldi binary ark file.
 *
 * The record layout is: name, NUL, "BFM ", control-D, uint32 rows, control-D,
 * uint32 columns, then numRows * numColumns floats taken from ptrMemory.
 *
 * @param fileName     destination path
 * @param shouldAppend append to an existing file instead of truncating it
 * @param name         record name
 * @param ptrMemory    numRows * numColumns floats
 * @param numRows      number of rows
 * @param numColumns   number of columns
 * @throws std::runtime_error if the file cannot be opened or the write fails
 */
void SaveKaldiArkArray(const char *fileName,
                       bool shouldAppend,
                       std::string name,
                       void *ptrMemory,
                       uint32_t numRows,
                       uint32_t numColumns) {
    std::ios_base::openmode mode = std::ios::binary;
    if (shouldAppend) {
        mode |= std::ios::app;
    }
    std::ofstream out_file(fileName, mode);
    if (!out_file.good()) {
        // Build the message by concatenation; the old text contained a printf-style "%s"
        // that was never substituted.
        throw std::runtime_error(std::string("Failed to open ") + fileName + " for writing in SaveKaldiArkArray()!");
    }

    // 64-bit so the byte count cannot wrap for large matrices
    const uint64_t numBytes = static_cast<uint64_t>(numRows) * numColumns * sizeof(float);
    out_file.write(name.c_str(), static_cast<std::streamsize>(name.length()));  // write name
    out_file.write("\0", 1);
    out_file.write("BFM ", 4);
    out_file.write("\4", 1);
    out_file.write(reinterpret_cast<const char *>(&numRows), sizeof(uint32_t));
    out_file.write("\4", 1);
    out_file.write(reinterpret_cast<const char *>(&numColumns), sizeof(uint32_t));
    out_file.write(static_cast<const char *>(ptrMemory), static_cast<std::streamsize>(numBytes));
    out_file.close();
    if (out_file.fail()) {
        throw std::runtime_error(std::string("Failed to write ") + fileName + " in SaveKaldiArkArray()!");
    }
}
166
/**
 * @brief Computes the factor that maps the largest magnitude in a float buffer onto targetMax.
 * @param ptrFloatMemory buffer holding at least numElements floats
 * @param targetMax      value the peak magnitude should be scaled to
 * @param numElements    number of floats to scan
 * @return targetMax / max|x|, or 1.0 when every scanned element is zero or numElements is 0
 */
float ScaleFactorForQuantization(void *ptrFloatMemory, float targetMax, uint32_t numElements) {
    const float *values = static_cast<const float *>(ptrFloatMemory);
    const float *const end = values + numElements;

    float peak = 0.0f;
    for (const float *cursor = values; cursor != end; ++cursor) {
        const float magnitude = fabs(*cursor);
        peak = (magnitude > peak) ? magnitude : peak;
    }

    // An all-zero buffer cannot be scaled meaningfully; fall back to identity.
    return (peak == 0.0f) ? 1.0f : targetMax / peak;
}
186
187 void ClearScoreError(score_error_t *error) {
188     error->numScores = 0;
189     error->numErrors = 0;
190     error->maxError = 0.0;
191     error->rmsError = 0.0;
192     error->sumError = 0.0;
193     error->sumRmsError = 0.0;
194     error->sumSquaredError = 0.0;
195     error->maxRelError = 0.0;
196     error->sumRelError = 0.0;
197     error->sumSquaredRelError = 0.0;
198 }
199
200 void UpdateScoreError(score_error_t *error, score_error_t *totalError) {
201     totalError->numErrors += error->numErrors;
202     totalError->numScores += error->numScores;
203     totalError->sumRmsError += error->rmsError;
204     totalError->sumError += error->sumError;
205     totalError->sumSquaredError += error->sumSquaredError;
206     if (error->maxError > totalError->maxError) {
207         totalError->maxError = error->maxError;
208     }
209     totalError->sumRelError += error->sumRelError;
210     totalError->sumSquaredRelError += error->sumSquaredRelError;
211     if (error->maxRelError > totalError->maxRelError) {
212         totalError->maxRelError = error->maxRelError;
213     }
214 }
215
216 uint32_t CompareScores(float *ptrScoreArray,
217                        void *ptrRefScoreArray,
218                        score_error_t *scoreError,
219                        uint32_t numRows,
220                        uint32_t numColumns) {
221     uint32_t numErrors = 0;
222
223     ClearScoreError(scoreError);
224
225     float *A = ptrScoreArray;
226     float *B = reinterpret_cast<float *>(ptrRefScoreArray);
227     for (uint32_t i = 0; i < numRows; i++) {
228         for (uint32_t j = 0; j < numColumns; j++) {
229             float score = A[i * numColumns + j];
230             float refscore = B[i * numColumns + j];
231             float error = fabs(refscore - score);
232             float rel_error = error / (static_cast<float>(fabs(refscore)) + 1e-20f);
233             float squared_error = error * error;
234             float squared_rel_error = rel_error * rel_error;
235             scoreError->numScores++;
236             scoreError->sumError += error;
237             scoreError->sumSquaredError += squared_error;
238             if (error > scoreError->maxError) {
239                 scoreError->maxError = error;
240             }
241             scoreError->sumRelError += rel_error;
242             scoreError->sumSquaredRelError += squared_rel_error;
243             if (rel_error > scoreError->maxRelError) {
244                 scoreError->maxRelError = rel_error;
245             }
246             if (error > scoreError->threshold) {
247                 numErrors++;
248             }
249         }
250     }
251     scoreError->rmsError = sqrt(scoreError->sumSquaredError / (numRows * numColumns));
252     scoreError->sumRmsError += scoreError->rmsError;
253     scoreError->numErrors = numErrors;
254
255     return (numErrors);
256 }
257
258 float StdDevError(score_error_t error) {
259     return (sqrt(error.sumSquaredError / error.numScores
260                  - (error.sumError / error.numScores) * (error.sumError / error.numScores)));
261 }
262
263 float StdDevRelError(score_error_t error) {
264     return (sqrt(error.sumSquaredRelError / error.numScores
265                  - (error.sumRelError / error.numScores) * (error.sumRelError / error.numScores)));
266 }
267
268 #if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
269 #if defined(_WIN32) || defined(WIN32)
270 #include <intrin.h>
271 #include <windows.h>
272 #else
273
274 #include <cpuid.h>
275
276 #endif
277
278 inline void native_cpuid(unsigned int *eax, unsigned int *ebx,
279                          unsigned int *ecx, unsigned int *edx) {
280     size_t level = *eax;
281 #if defined(_WIN32) || defined(WIN32)
282     int regs[4] = {static_cast<int>(*eax), static_cast<int>(*ebx), static_cast<int>(*ecx), static_cast<int>(*edx)};
283     __cpuid(regs, level);
284     *eax = static_cast<uint32_t>(regs[0]);
285     *ebx = static_cast<uint32_t>(regs[1]);
286     *ecx = static_cast<uint32_t>(regs[2]);
287     *edx = static_cast<uint32_t>(regs[3]);
288 #else
289     __get_cpuid(level, eax, ebx, ecx, edx);
290 #endif
291 }
292
293 // return GNA module frequency in MHz
294 float getGnaFrequencyMHz() {
295     uint32_t eax = 1;
296     uint32_t ebx = 0;
297     uint32_t ecx = 0;
298     uint32_t edx = 0;
299     uint32_t family = 0;
300     uint32_t model = 0;
301     const uint8_t sixth_family = 6;
302     const uint8_t cannon_lake_model = 102;
303     const uint8_t gemini_lake_model = 122;
304
305     native_cpuid(&eax, &ebx, &ecx, &edx);
306     family = (eax >> 8) & 0xF;
307
308     // model is the concatenation of two fields
309     // | extended model | model |
310     // copy extended model data
311     model = (eax >> 16) & 0xF;
312     // shift
313     model <<= 4;
314     // copy model data
315     model += (eax >> 4) & 0xF;
316
317     if (family == sixth_family && model == cannon_lake_model) {
318         return 400;
319     } else if (family == sixth_family &&
320                model == gemini_lake_model) {
321         return 200;
322     } else {
323         // counters not supported and we retrns just default value
324         return 1;
325     }
326 }
327
328 #endif  // if not ARM
329
330 void printReferenceCompareResults(score_error_t const &totalError,
331                                   size_t framesNum,
332                                   std::ostream &stream) {
333     stream << "         max error: " <<
334            totalError.maxError << std::endl;
335     stream << "         avg error: " <<
336            totalError.sumError / totalError.numScores << std::endl;
337     stream << "     avg rms error: " <<
338            totalError.sumRmsError / framesNum << std::endl;
339     stream << "       stdev error: " <<
340            StdDevError(totalError) << std::endl << std::endl;
341     stream << std::endl;
342 }
343
344 void printPerformanceCounters(std::map<std::string,
345         InferenceEngine::InferenceEngineProfileInfo> const &utterancePerfMap,
346                               size_t callsNum,
347                               std::ostream &stream, std::string fullDeviceName) {
348 #if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
349     stream << std::endl << "Performance counts:" << std::endl;
350     stream << std::setw(10) << std::right << "" << "Counter descriptions";
351     stream << std::setw(22) << "Utt scoring time";
352     stream << std::setw(18) << "Avg infer time";
353     stream << std::endl;
354
355     stream << std::setw(46) << "(ms)";
356     stream << std::setw(24) << "(us per call)";
357     stream << std::endl;
358
359     for (const auto &it : utterancePerfMap) {
360         std::string const &counter_name = it.first;
361         float current_units = static_cast<float>(it.second.realTime_uSec);
362         float call_units = current_units / callsNum;
363         // if GNA HW counters
364         // get frequency of GNA module
365         float freq = getGnaFrequencyMHz();
366         current_units /= freq * 1000;
367         call_units /= freq;
368         stream << std::setw(30) << std::left << counter_name.substr(4, counter_name.size() - 1);
369         stream << std::setw(16) << std::right << current_units;
370         stream << std::setw(21) << std::right << call_units;
371         stream << std::endl;
372     }
373     stream << std::endl;
374     std::cout << std::endl;
375     std::cout << "Full device name: " << fullDeviceName << std::endl;
376     std::cout << std::endl;
377 #endif
378 }
379
380 void getPerformanceCounters(InferenceEngine::InferRequest &request,
381                             std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &perfCounters) {
382     auto retPerfCounters = request.GetPerformanceCounts();
383
384     for (const auto &pair : retPerfCounters) {
385         perfCounters[pair.first] = pair.second;
386     }
387 }
388
389 void sumPerformanceCounters(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> const &perfCounters,
390                             std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &totalPerfCounters) {
391     for (const auto &pair : perfCounters) {
392         totalPerfCounters[pair.first].realTime_uSec += pair.second.realTime_uSec;
393     }
394 }
395
396 bool ParseAndCheckCommandLine(int argc, char *argv[]) {
397     // ---------------------------Parsing and validation of input args--------------------------------------
398     slog::info << "Parsing input parameters" << slog::endl;
399
400     gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
401     if (FLAGS_h) {
402         showUsage();
403         showAvailableDevices();
404         return false;
405     }
406     bool isDumpMode = !FLAGS_wg.empty() || !FLAGS_we.empty();
407
408     // input not required only in dump mode and if external scale factor provided
409     if (FLAGS_i.empty() && (!isDumpMode || FLAGS_q.compare("user") != 0)) {
410         if (isDumpMode) {
411             throw std::logic_error("In model dump mode either static quantization is used (-i) or user scale"
412                                    " factor need to be provided. See -q user option");
413         }
414         throw std::logic_error("Input file not set. Please use -i.");
415     }
416
417     if (FLAGS_m.empty() && FLAGS_rg.empty()) {
418         throw std::logic_error("Either IR file (-m) or GNAModel file (-rg) need to be set.");
419     }
420
421     if ((!FLAGS_m.empty() && !FLAGS_rg.empty())) {
422         throw std::logic_error("Only one of -m and -rg is allowed.");
423     }
424
425     std::vector<std::string> supportedDevices = {
426             "CPU",
427             "GPU",
428             "GNA_AUTO",
429             "GNA_HW",
430             "GNA_SW_EXACT",
431             "GNA_SW",
432             "GNA_SW_FP32",
433             "HETERO:GNA,CPU",
434             "HETERO:GNA_HW,CPU",
435             "HETERO:GNA_SW_EXACT,CPU",
436             "HETERO:GNA_SW,CPU",
437             "HETERO:GNA_SW_FP32,CPU",
438             "MYRIAD"
439     };
440
441     if (std::find(supportedDevices.begin(), supportedDevices.end(), FLAGS_d) == supportedDevices.end()) {
442         throw std::logic_error("Specified device is not supported.");
443     }
444
445     float scaleFactorInput = static_cast<float>(FLAGS_sf);
446     if (scaleFactorInput <= 0.0f) {
447         throw std::logic_error("Scale factor out of range (must be non-negative).");
448     }
449
450     uint32_t batchSize = (uint32_t) FLAGS_bs;
451     if ((batchSize < 1) || (batchSize > 8)) {
452         throw std::logic_error("Batch size out of range (1..8).");
453     }
454
455     /** default is a static quantisation **/
456     if ((FLAGS_q.compare("static") != 0) && (FLAGS_q.compare("dynamic") != 0) && (FLAGS_q.compare("user") != 0)) {
457         throw std::logic_error("Quantization mode not supported (static, dynamic, user).");
458     }
459
460     if (FLAGS_q.compare("dynamic") == 0) {
461         throw std::logic_error("Dynamic quantization not yet supported.");
462     }
463
464     if (FLAGS_qb != 16 && FLAGS_qb != 8) {
465         throw std::logic_error("Only 8 or 16 bits supported.");
466     }
467
468     if (FLAGS_nthreads <= 0) {
469         throw std::logic_error("Invalid value for 'nthreads' argument. It must be greater that or equal to 0");
470     }
471
472     if (FLAGS_cw_r < 0) {
473         throw std::logic_error("Invalid value for 'cw_r' argument. It must be greater than or equal to 0");
474     }
475
476     if (FLAGS_cw_l < 0) {
477         throw std::logic_error("Invalid value for 'cw_l' argument. It must be greater than or equal to 0");
478     }
479
480     return true;
481 }
482
483 /**
484  * @brief The entry point for inference engine automatic speech recognition sample
485  * @file speech_sample/main.cpp
486  * @example speech_sample/main.cpp
487  */
488 int main(int argc, char *argv[]) {
489     try {
490         slog::info << "InferenceEngine: " << GetInferenceEngineVersion() << slog::endl;
491
492         // ------------------------------ Parsing and validation of input args ---------------------------------
493         if (!ParseAndCheckCommandLine(argc, argv)) {
494             return 0;
495         }
496
497         if (FLAGS_l.empty()) {
498             slog::info << "No extensions provided" << slog::endl;
499         }
500
501         auto isFeature = [&](const std::string xFeature) { return FLAGS_d.find(xFeature) != std::string::npos; };
502
503         bool useGna = isFeature("GNA");
504         bool useHetero = isFeature("HETERO");
505         std::string deviceStr =
506                 useHetero && useGna ? "HETERO:GNA,CPU" : FLAGS_d.substr(0, (FLAGS_d.find("_")));
507         float scaleFactorInput = static_cast<float>(FLAGS_sf);
508         uint32_t batchSize = (FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : (uint32_t) FLAGS_bs;
509
510         std::vector<std::string> inputArkFiles;
511         std::vector<uint32_t> numBytesThisUtterance;
512         uint32_t numUtterances(0);
513         if (!FLAGS_i.empty()) {
514             std::string outStr;
515             std::istringstream stream(FLAGS_i);
516
517             uint32_t currentNumUtterances(0), currentNumBytesThisUtterance(0);
518             while (getline(stream, outStr, ',')) {
519                 std::string filename(fileNameNoExt(outStr) + ".ark");
520                 inputArkFiles.push_back(filename);
521
522                 GetKaldiArkInfo(filename.c_str(), 0, &currentNumUtterances, &currentNumBytesThisUtterance);
523                 if (numUtterances == 0) {
524                     numUtterances = currentNumUtterances;
525                 } else if (currentNumUtterances != numUtterances) {
526                     throw std::logic_error("Incorrect input files. Number of utterance must be the same for all ark files");
527                 }
528                 numBytesThisUtterance.push_back(currentNumBytesThisUtterance);
529             }
530         }
531         size_t numInputArkFiles(inputArkFiles.size());
532         // -----------------------------------------------------------------------------------------------------
533
534         // --------------------------- 1. Load inference engine -------------------------------------
535         slog::info << "Loading Inference Engine" << slog::endl;
536         Core ie;
537
538         /** Printing device version **/
539         slog::info << "Device info: " << slog::endl;
540         std::cout << ie.GetVersions(deviceStr) << std::endl;
541         // -----------------------------------------------------------------------------------------------------
542
543         // --------------------------- 2. Read IR Generated by ModelOptimizer (.xml and .bin files) ------------
544         slog::info << "Loading network files" << slog::endl;
545
546         CNNNetReader netBuilder;
547         if (!FLAGS_m.empty()) {
548             /** Read network model **/
549             netBuilder.ReadNetwork(FLAGS_m);
550
551             /** Extract model name and load weights **/
552             std::string binFileName = fileNameNoExt(FLAGS_m) + ".bin";
553             netBuilder.ReadWeights(binFileName);
554
555             // -------------------------------------------------------------------------------------------------
556
557             // --------------------------- 3. Set batch size ---------------------------------------------------
558             /** Set batch size.  Unlike in imaging, batching in time (rather than space) is done for speech recognition. **/
559             netBuilder.getNetwork().setBatchSize(batchSize);
560             slog::info << "Batch size is " << std::to_string(netBuilder.getNetwork().getBatchSize())
561                        << slog::endl;
562         }
563
564         /** Setting parameter for per layer metrics **/
565         std::map<std::string, std::string> gnaPluginConfig;
566         std::map<std::string, std::string> genericPluginConfig;
567         if (useGna) {
568             std::string gnaDevice =
569                     useHetero ? FLAGS_d.substr(FLAGS_d.find("GNA"), FLAGS_d.find(",") - FLAGS_d.find("GNA")) : FLAGS_d;
570             gnaPluginConfig[GNAConfigParams::KEY_GNA_DEVICE_MODE] =
571                     gnaDevice.find("_") == std::string::npos ? "GNA_AUTO" : gnaDevice;
572         } else if (deviceStr.find("CPU") != std::string::npos) {
573             /**
574              * cpu_extensions library is compiled from "extension" folder containing
575              * custom MKLDNNPlugin layer implementations. These layers are not supported
576              * by mkldnn, but they can be useful for inferring custom topologies.
577             **/
578             ie.AddExtension(std::make_shared<Extensions::Cpu::CpuExtensions>(), "CPU");
579         }
580
581         if (FLAGS_pc) {
582             genericPluginConfig[PluginConfigParams::KEY_PERF_COUNT] = PluginConfigParams::YES;
583         }
584
585         if (FLAGS_q.compare("user") == 0) {
586             if (numInputArkFiles > 1) {
587                 std::string errMessage("Incorrect use case for multiple input ark files. Please don't use -q 'user' for this case.");
588                 throw std::logic_error(errMessage);
589             }
590             slog::info << "Using scale factor of " << FLAGS_sf << slog::endl;
591             gnaPluginConfig[GNA_CONFIG_KEY(SCALE_FACTOR)] = std::to_string(FLAGS_sf);
592         } else {
593             // "static" quantization with calculated scale factor
594             for (size_t i = 0; i < numInputArkFiles; i++) {
595                 auto inputArkName = inputArkFiles[i].c_str();
596                 std::string name;
597                 std::vector<uint8_t> ptrFeatures;
598                 uint32_t numArrays(0), numBytes(0), numFrames(0), numFrameElements(0), numBytesPerElement(0);
599                 GetKaldiArkInfo(inputArkName, 0, &numArrays, &numBytes);
600                 ptrFeatures.resize(numBytes);
601                 LoadKaldiArkArray(inputArkName,
602                                   0,
603                                   name,
604                                   ptrFeatures,
605                                   &numFrames,
606                                   &numFrameElements,
607                                   &numBytesPerElement);
608                 scaleFactorInput =
609                         ScaleFactorForQuantization(ptrFeatures.data(), MAX_VAL_2B_FEAT, numFrames * numFrameElements);
610                 slog::info << "Using scale factor of " << scaleFactorInput << " calculated from first utterance."
611                            << slog::endl;
612                 std::string scaleFactorConfigKey = GNA_CONFIG_KEY(SCALE_FACTOR) + std::string("_") + std::to_string(i);
613                 gnaPluginConfig[scaleFactorConfigKey] = std::to_string(scaleFactorInput);
614             }
615         }
616
617         if (FLAGS_qb == 8) {
618             gnaPluginConfig[GNAConfigParams::KEY_GNA_PRECISION] = "I8";
619         } else {
620             gnaPluginConfig[GNAConfigParams::KEY_GNA_PRECISION] = "I16";
621         }
622
623         gnaPluginConfig[GNAConfigParams::KEY_GNA_LIB_N_THREADS] = std::to_string((FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : FLAGS_nthreads);
624         gnaPluginConfig[GNA_CONFIG_KEY(COMPACT_MODE)] = CONFIG_VALUE(NO);
625         // -----------------------------------------------------------------------------------------------------
626
627         // --------------------------- 4. Write model to file --------------------------------------------------
628         // Embedded GNA model dumping (for Intel(R) Speech Enabling Developer Kit)
629         if (!FLAGS_we.empty()) {
630             gnaPluginConfig[GNAConfigParams::KEY_GNA_FIRMWARE_MODEL_IMAGE] = FLAGS_we;
631         }
632         // -----------------------------------------------------------------------------------------------------
633
634         // --------------------------- 5. Loading model to the device ------------------------------------------
635
636         if (useGna) {
637             genericPluginConfig.insert(std::begin(gnaPluginConfig), std::end(gnaPluginConfig));
638         }
639         auto t0 = Time::now();
640         ExecutableNetwork executableNet;
641
642         if (!FLAGS_m.empty()) {
643             slog::info << "Loading model to the device" << slog::endl;
644             executableNet = ie.LoadNetwork(netBuilder.getNetwork(), deviceStr, genericPluginConfig);
645         } else {
646             slog::info << "Importing model to the device" << slog::endl;
647             executableNet = ie.ImportNetwork(FLAGS_rg.c_str(), deviceStr, genericPluginConfig);
648         }
649
650         ms loadTime = std::chrono::duration_cast<ms>(Time::now() - t0);
651         slog::info << "Model loading time " << loadTime.count() << " ms" << slog::endl;
652
653         // --------------------------- 6. Exporting gna model using InferenceEngine AOT API---------------------
654         if (!FLAGS_wg.empty()) {
655             slog::info << "Writing GNA Model to file " << FLAGS_wg << slog::endl;
656             t0 = Time::now();
657             executableNet.Export(FLAGS_wg);
658             ms exportTime = std::chrono::duration_cast<ms>(Time::now() - t0);
659             slog::info << "Exporting time " << exportTime.count() << " ms" << slog::endl;
660             return 0;
661         }
662
663         if (!FLAGS_we.empty()) {
664             slog::info << "Exported GNA embedded model to file " << FLAGS_we << slog::endl;
665             return 0;
666         }
667
668         std::vector<InferRequestStruct> inferRequests((FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : FLAGS_nthreads);
669         for (auto& inferRequest : inferRequests) {
670             inferRequest = {executableNet.CreateInferRequest(), -1, batchSize};
671         }
672         // -----------------------------------------------------------------------------------------------------
673
674         // --------------------------- 7. Prepare input blobs --------------------------------------------------
675         /** Taking information about all topology inputs **/
676         ConstInputsDataMap cInputInfo = executableNet.GetInputsInfo();
677         /** Stores all input blobs data **/
678         if (cInputInfo.size() != numInputArkFiles) {
679             throw std::logic_error("Number of network inputs("
680                 + std::to_string(cInputInfo.size()) + ") is not equal to number of ark files("
681                 + std::to_string(numInputArkFiles) + ")");
682         }
683
684         std::vector<Blob::Ptr> ptrInputBlobs;
685         for (auto& input : cInputInfo) {
686             ptrInputBlobs.push_back(inferRequests.begin()->inferRequest.GetBlob(input.first));
687         }
688
689         InputsDataMap inputInfo;
690         if (!FLAGS_m.empty()) {
691             inputInfo = netBuilder.getNetwork().getInputsInfo();
692         }
693         /** configure input precision if model loaded from IR **/
694         for (auto &item : inputInfo) {
695             Precision inputPrecision = Precision::FP32;  // specify Precision::I16 to provide quantized inputs
696             item.second->setPrecision(inputPrecision);
697             item.second->getInputData()->setLayout(Layout::NC);  // row major layout
698         }
699
700         // -----------------------------------------------------------------------------------------------------
701
702         // --------------------------- 8. Prepare output blobs -------------------------------------------------
703         ConstOutputsDataMap cOutputInfo(executableNet.GetOutputsInfo());
704         OutputsDataMap outputInfo;
705         if (!FLAGS_m.empty()) {
706             outputInfo = netBuilder.getNetwork().getOutputsInfo();
707         }
708
709         Blob::Ptr ptrOutputBlob = inferRequests[0].inferRequest.GetBlob(cOutputInfo.rbegin()->first);
710
711         for (auto &item : outputInfo) {
712             DataPtr outData = item.second;
713             if (!outData) {
714                 throw std::logic_error("output data pointer is not valid");
715             }
716
717             Precision outputPrecision = Precision::FP32;  // specify Precision::I32 to retrieve quantized outputs
718             outData->setPrecision(outputPrecision);
719             outData->setLayout(Layout::NC);  // row major layout
720         }
721         // -----------------------------------------------------------------------------------------------------
722
723         // --------------------------- 9. Do inference ---------------------------------------------------------
724         std::vector<std::vector<uint8_t>> ptrUtterances;
725         std::vector<uint8_t> ptrScores;
726         std::vector<uint8_t> ptrReferenceScores;
727         score_error_t frameError, totalError;
728
729         ptrUtterances.resize(inputArkFiles.size());
730         for (uint32_t utteranceIndex = 0; utteranceIndex < numUtterances; ++utteranceIndex) {
731             std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> utterancePerfMap;
732             std::string uttName;
733             uint32_t numFrames(0), n(0);
734             std::vector<uint32_t> numFrameElementsInput;
735
736             uint32_t numFramesReference(0), numFrameElementsReference(0), numBytesPerElementReference(0),
737                     numBytesReferenceScoreThisUtterance(0);
738             const uint32_t numScoresPerFrame = ptrOutputBlob->size() / batchSize;
739
740             numFrameElementsInput.resize(numInputArkFiles);
741             for (size_t i = 0; i < inputArkFiles.size(); i++) {
742                 std::vector<uint8_t> ptrUtterance;
743                 auto inputArkFilename = inputArkFiles[i].c_str();
744                 uint32_t currentNumFrames(0), currentNumFrameElementsInput(0), currentNumBytesPerElementInput(0);
745                 GetKaldiArkInfo(inputArkFilename, utteranceIndex, &n, &numBytesThisUtterance[i]);
746                 ptrUtterance.resize(numBytesThisUtterance[i]);
747                 LoadKaldiArkArray(inputArkFilename,
748                                   utteranceIndex,
749                                   uttName,
750                                   ptrUtterance,
751                                   &currentNumFrames,
752                                   &currentNumFrameElementsInput,
753                                   &currentNumBytesPerElementInput);
754                 if (numFrames == 0) {
755                     numFrames = currentNumFrames;
756                 } else if (numFrames != currentNumFrames) {
757                     std::string errMessage("Number of frames in ark files is different: " + std::to_string(numFrames) +
758                                            " and " + std::to_string(currentNumFrames));
759                     throw std::logic_error(errMessage);
760                 }
761
762                 ptrUtterances[i] = ptrUtterance;
763                 numFrameElementsInput[i] = currentNumFrameElementsInput;
764             }
765
766             int i = 0;
767             for (auto& ptrInputBlob : ptrInputBlobs) {
768                 if (ptrInputBlob->size() != numFrameElementsInput[i++] * batchSize) {
769                     throw std::logic_error("network input size(" + std::to_string(ptrInputBlob->size()) +
770                                            ") mismatch to ark file size (" +
771                                            std::to_string(numFrameElementsInput[i-1] * batchSize) + ")");
772                 }
773             }
774
775             ptrScores.resize(numFrames * numScoresPerFrame * sizeof(float));
776             if (!FLAGS_r.empty()) {
777                 std::string refUtteranceName;
778                 GetKaldiArkInfo(FLAGS_r.c_str(), utteranceIndex, &n, &numBytesReferenceScoreThisUtterance);
779                 ptrReferenceScores.resize(numBytesReferenceScoreThisUtterance);
780                 LoadKaldiArkArray(FLAGS_r.c_str(),
781                                   utteranceIndex,
782                                   refUtteranceName,
783                                   ptrReferenceScores,
784                                   &numFramesReference,
785                                   &numFrameElementsReference,
786                                   &numBytesPerElementReference);
787             }
788
789             double totalTime = 0.0;
790
791             std::cout << "Utterance " << utteranceIndex << ": " << std::endl;
792
793             ClearScoreError(&totalError);
794             totalError.threshold = frameError.threshold = MAX_SCORE_DIFFERENCE;
795             auto outputFrame = &ptrScores.front();
796             std::vector<uint8_t*> inputFrame;
797             for (auto& ut : ptrUtterances) {
798                 inputFrame.push_back(&ut.front());
799             }
800
801             std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> callPerfMap;
802
803             size_t frameIndex = 0;
804             uint32_t numFramesArkFile = numFrames;
805             numFrames += FLAGS_cw_l + FLAGS_cw_r;
806             uint32_t numFramesThisBatch{batchSize};
807
808             auto t0 = Time::now();
809             auto t1 = t0;
810
811             while (frameIndex <= numFrames) {
812                 if (frameIndex == numFrames) {
813                     if (std::find_if(inferRequests.begin(),
814                             inferRequests.end(),
815                             [&](InferRequestStruct x) { return (x.frameIndex != -1); } ) == inferRequests.end()) {
816                         break;
817                     }
818                 }
819
820                 bool inferRequestFetched = false;
821                 for (auto &inferRequest : inferRequests) {
822                     if (frameIndex == numFrames) {
823                         numFramesThisBatch = 1;
824                     } else {
825                         numFramesThisBatch = (numFrames - frameIndex < batchSize) ? (numFrames - frameIndex)
826                                                                                   : batchSize;
827                     }
828
829                     if (inferRequest.frameIndex != -1) {
830                         StatusCode code = inferRequest.inferRequest.Wait(
831                                 InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
832
833                         if (code != StatusCode::OK) {
834                             if (!useHetero) continue;
835                             if (code != StatusCode::INFER_NOT_STARTED) continue;
836                         }
837
838                         if (inferRequest.frameIndex >= 0) {
839                             if (!FLAGS_o.empty()) {
840                                 outputFrame =
841                                         &ptrScores.front() + numScoresPerFrame * sizeof(float) * (inferRequest.frameIndex);
842                                 Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(cOutputInfo.begin()->first);
843                                 auto byteSize = inferRequest.numFramesThisBatch * numScoresPerFrame * sizeof(float);
844                                 std::memcpy(outputFrame,
845                                             outputBlob->buffer(),
846                                             byteSize);
847                             }
848
849                             if (!FLAGS_r.empty()) {
850                                 Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(cOutputInfo.begin()->first);
851                                 CompareScores(outputBlob->buffer().as<float *>(),
852                                               &ptrReferenceScores[inferRequest.frameIndex *
853                                                                   numFrameElementsReference *
854                                                                   numBytesPerElementReference],
855                                               &frameError,
856                                               inferRequest.numFramesThisBatch,
857                                               numFrameElementsReference);
858                                 UpdateScoreError(&frameError, &totalError);
859                             }
860                             if (FLAGS_pc) {
861                                 // retrive new counters
862                                 getPerformanceCounters(inferRequest.inferRequest, callPerfMap);
863                                 // summarize retrived counters with all previous
864                                 sumPerformanceCounters(callPerfMap, utterancePerfMap);
865                             }
866                         }
867                     }
868
869                     if (frameIndex == numFrames) {
870                         inferRequest.frameIndex = -1;
871                         continue;
872                     }
873
874                     ptrInputBlobs.clear();
875                     for (auto& input : cInputInfo) {
876                         ptrInputBlobs.push_back(inferRequest.inferRequest.GetBlob(input.first));
877                     }
878
879                     for (size_t i = 0; i < numInputArkFiles; i++) {
880                         std::memcpy(ptrInputBlobs[i]->buffer(),
881                                     inputFrame[i],
882                                     ptrInputBlobs[i]->byteSize());
883                     }
884
885                     int index = static_cast<int>(frameIndex) - (FLAGS_cw_l + FLAGS_cw_r);
886                     inferRequest.inferRequest.StartAsync();
887                     inferRequest.frameIndex = index < 0 ? -2 : index;
888                     inferRequest.numFramesThisBatch = numFramesThisBatch;
889
890                     frameIndex += numFramesThisBatch;
891                     for (size_t j = 0; j < inputArkFiles.size(); j++) {
892                         if (FLAGS_cw_l > 0 || FLAGS_cw_r > 0) {
893                             int i = frameIndex - FLAGS_cw_l;
894                             if (i > 0 && i < static_cast<int>(numFramesArkFile)) {
895                                 inputFrame[j] += sizeof(float) * numFrameElementsInput[j] * numFramesThisBatch;
896                             } else if (i >= static_cast<int>(numFramesArkFile)) {
897                                 inputFrame[j] = &ptrUtterances[0].front() +
898                                         (numFramesArkFile - 1) * sizeof(float) * numFrameElementsInput[j] * numFramesThisBatch;
899                             } else if (i < 0) {
900                                 inputFrame[j] = &ptrUtterances[0].front();
901                             }
902                         } else {
903                             inputFrame[j] += sizeof(float) * numFrameElementsInput[j] * numFramesThisBatch;
904                         }
905                     }
906                     inferRequestFetched |= true;
907                 }
908
909                 if (!inferRequestFetched) {
910                     std::this_thread::sleep_for(std::chrono::milliseconds(1));
911                     continue;
912                 }
913             }
914             t1 = Time::now();
915
916             fsec fs = t1 - t0;
917             ms d = std::chrono::duration_cast<ms>(fs);
918             totalTime += d.count();
919
920             // resetting state between utterances
921             for (auto &&state : executableNet.QueryState()) {
922                 state.Reset();
923             }
924
925             if (!FLAGS_o.empty()) {
926                 bool shouldAppend = (utteranceIndex == 0) ? false : true;
927                 SaveKaldiArkArray(FLAGS_o.c_str(), shouldAppend, uttName, &ptrScores.front(),
928                                   numFrames, numScoresPerFrame);
929             }
930
931             /** Show performance results **/
932             std::cout << "Total time in Infer (HW and SW):\t" << totalTime << " ms"
933                       << std::endl;
934             std::cout << "Frames in utterance:\t\t\t" << numFrames << " frames"
935                       << std::endl;
936             std::cout << "Average Infer time per frame:\t\t" << totalTime / static_cast<double>(numFrames) << " ms"
937                       << std::endl;
938             if (FLAGS_pc) {
939                 // print
940                 printPerformanceCounters(utterancePerfMap, frameIndex, std::cout, getFullDeviceName(ie, FLAGS_d));
941             }
942             if (!FLAGS_r.empty()) {
943                 printReferenceCompareResults(totalError, numFrames, std::cout);
944             }
945             std::cout << "End of Utterance " << utteranceIndex << std::endl << std::endl;
946         }
947         // -----------------------------------------------------------------------------------------------------
948     }
949     catch (const std::exception &error) {
950         slog::err << error.what() << slog::endl;
951         return 1;
952     }
953     catch (...) {
954         slog::err << "Unknown/internal exception happened" << slog::endl;
955         return 1;
956     }
957
958     slog::info << "Execution successful" << slog::endl;
959     return 0;
960 }