Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / samples / speech_sample / main.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "speech_sample.hpp"
6
7 #include <gflags/gflags.h>
8 #include <functional>
9 #include <iostream>
10 #include <memory>
11 #include <map>
12 #include <fstream>
13 #include <random>
14 #include <string>
15 #include <vector>
16 #include <utility>
17 #include <time.h>
18 #include <thread>
19 #include <chrono>
20 #include <limits>
21 #include <iomanip>
22 #include <inference_engine.hpp>
23 #include <gna/gna_config.hpp>
24
25 #include <samples/common.hpp>
26 #include <samples/slog.hpp>
27 #include <samples/args_helper.hpp>
28 #include <ext_list.hpp>
29
30 #ifndef ALIGN
31 #define ALIGN(memSize, pad)   ((static_cast<int>((memSize) + pad - 1) / pad) * pad)
32 #endif
33 #define MAX_SCORE_DIFFERENCE 0.0001f
34 #define MAX_VAL_2B_FEAT 16384
35
36 using namespace InferenceEngine;
37
38 typedef std::chrono::high_resolution_clock Time;
39 typedef std::chrono::duration<double, std::ratio<1, 1000>> ms;
40 typedef std::chrono::duration<float> fsec;
41 typedef struct {
42     uint32_t numScores;
43     uint32_t numErrors;
44     float threshold;
45     float maxError;
46     float rmsError;
47     float sumError;
48     float sumRmsError;
49     float sumSquaredError;
50     float maxRelError;
51     float sumRelError;
52     float sumSquaredRelError;
53 } score_error_t;
54
55 struct InferRequestStruct {
56     InferRequest inferRequest;
57     int frameIndex;
58     uint32_t numFramesThisBatch;
59 };
60
61 void GetKaldiArkInfo(const char *fileName,
62                      uint32_t numArrayToFindSize,
63                      uint32_t *ptrNumArrays,
64                      uint32_t *ptrNumMemoryBytes) {
65     uint32_t numArrays = 0;
66     uint32_t numMemoryBytes = 0;
67
68     std::ifstream in_file(fileName, std::ios::binary);
69     if (in_file.good()) {
70         while (!in_file.eof()) {
71             std::string line;
72             uint32_t numRows = 0u, numCols = 0u, num_bytes = 0u;
73             std::getline(in_file, line, '\0');  // read variable length name followed by space and NUL
74             std::getline(in_file, line, '\4');  // read "BFM" followed by space and control-D
75             if (line.compare("BFM ") != 0) {
76                 break;
77             }
78             in_file.read(reinterpret_cast<char *>(&numRows), sizeof(uint32_t));  // read number of rows
79             std::getline(in_file, line, '\4');                                   // read control-D
80             in_file.read(reinterpret_cast<char *>(&numCols), sizeof(uint32_t));  // read number of columns
81             num_bytes = numRows * numCols * sizeof(float);
82             in_file.seekg(num_bytes, in_file.cur);                               // read data
83
84             if (numArrays == numArrayToFindSize) {
85                 numMemoryBytes += num_bytes;
86             }
87             numArrays++;
88         }
89         in_file.close();
90     } else {
91         fprintf(stderr, "Failed to open %s for reading in GetKaldiArkInfo()!\n", fileName);
92         exit(-1);
93     }
94
95     if (ptrNumArrays != NULL) *ptrNumArrays = numArrays;
96     if (ptrNumMemoryBytes != NULL) *ptrNumMemoryBytes = numMemoryBytes;
97 }
98
99 void LoadKaldiArkArray(const char *fileName, uint32_t arrayIndex, std::string &ptrName, std::vector<uint8_t> &memory,
100                        uint32_t *ptrNumRows, uint32_t *ptrNumColumns, uint32_t *ptrNumBytesPerElement) {
101     std::ifstream in_file(fileName, std::ios::binary);
102     if (in_file.good()) {
103         uint32_t i = 0;
104         while (i < arrayIndex) {
105             std::string line;
106             uint32_t numRows = 0u, numCols = 0u;
107             std::getline(in_file, line, '\0');  // read variable length name followed by space and NUL
108             std::getline(in_file, line, '\4');  // read "BFM" followed by space and control-D
109             if (line.compare("BFM ") != 0) {
110                 break;
111             }
112             in_file.read(reinterpret_cast<char *>(&numRows), sizeof(uint32_t));     // read number of rows
113             std::getline(in_file, line, '\4');                                     // read control-D
114             in_file.read(reinterpret_cast<char *>(&numCols), sizeof(uint32_t));     // read number of columns
115             in_file.seekg(numRows * numCols * sizeof(float), in_file.cur);         // read data
116             i++;
117         }
118         if (!in_file.eof()) {
119             std::string line;
120             std::getline(in_file, ptrName, '\0');     // read variable length name followed by space and NUL
121             std::getline(in_file, line, '\4');       // read "BFM" followed by space and control-D
122             if (line.compare("BFM ") != 0) {
123                 fprintf(stderr, "Cannot find array specifier in file %s in LoadKaldiArkArray()!\n", fileName);
124                 exit(-1);
125             }
126             in_file.read(reinterpret_cast<char *>(ptrNumRows), sizeof(uint32_t));        // read number of rows
127             std::getline(in_file, line, '\4');                                            // read control-D
128             in_file.read(reinterpret_cast<char *>(ptrNumColumns), sizeof(uint32_t));    // read number of columns
129             in_file.read(reinterpret_cast<char *>(&memory.front()),
130                          *ptrNumRows * *ptrNumColumns * sizeof(float));  // read array data
131         }
132         in_file.close();
133     } else {
134         fprintf(stderr, "Failed to open %s for reading in GetKaldiArkInfo()!\n", fileName);
135         exit(-1);
136     }
137
138     *ptrNumBytesPerElement = sizeof(float);
139 }
140
141 void SaveKaldiArkArray(const char *fileName,
142                        bool shouldAppend,
143                        std::string name,
144                        void *ptrMemory,
145                        uint32_t numRows,
146                        uint32_t numColumns) {
147     std::ios_base::openmode mode = std::ios::binary;
148     if (shouldAppend) {
149         mode |= std::ios::app;
150     }
151     std::ofstream out_file(fileName, mode);
152     if (out_file.good()) {
153         out_file.write(name.c_str(), name.length());  // write name
154         out_file.write("\0", 1);
155         out_file.write("BFM ", 4);
156         out_file.write("\4", 1);
157         out_file.write(reinterpret_cast<char *>(&numRows), sizeof(uint32_t));
158         out_file.write("\4", 1);
159         out_file.write(reinterpret_cast<char *>(&numColumns), sizeof(uint32_t));
160         out_file.write(reinterpret_cast<char *>(ptrMemory), numRows * numColumns * sizeof(float));
161         out_file.close();
162     } else {
163         throw std::runtime_error(std::string("Failed to open %s for writing in SaveKaldiArkArray()!\n") + fileName);
164     }
165 }
166
167 float ScaleFactorForQuantization(void *ptrFloatMemory, float targetMax, uint32_t numElements) {
168     float *ptrFloatFeat = reinterpret_cast<float *>(ptrFloatMemory);
169     float max = 0.0;
170     float scaleFactor;
171
172     for (uint32_t i = 0; i < numElements; i++) {
173         if (fabs(ptrFloatFeat[i]) > max) {
174             max = fabs(ptrFloatFeat[i]);
175         }
176     }
177
178     if (max == 0) {
179         scaleFactor = 1.0;
180     } else {
181         scaleFactor = targetMax / max;
182     }
183
184     return (scaleFactor);
185 }
186
187 void ClearScoreError(score_error_t *error) {
188     error->numScores = 0;
189     error->numErrors = 0;
190     error->maxError = 0.0;
191     error->rmsError = 0.0;
192     error->sumError = 0.0;
193     error->sumRmsError = 0.0;
194     error->sumSquaredError = 0.0;
195     error->maxRelError = 0.0;
196     error->sumRelError = 0.0;
197     error->sumSquaredRelError = 0.0;
198 }
199
200 void UpdateScoreError(score_error_t *error, score_error_t *totalError) {
201     totalError->numErrors += error->numErrors;
202     totalError->numScores += error->numScores;
203     totalError->sumRmsError += error->rmsError;
204     totalError->sumError += error->sumError;
205     totalError->sumSquaredError += error->sumSquaredError;
206     if (error->maxError > totalError->maxError) {
207         totalError->maxError = error->maxError;
208     }
209     totalError->sumRelError += error->sumRelError;
210     totalError->sumSquaredRelError += error->sumSquaredRelError;
211     if (error->maxRelError > totalError->maxRelError) {
212         totalError->maxRelError = error->maxRelError;
213     }
214 }
215
216 uint32_t CompareScores(float *ptrScoreArray,
217                        void *ptrRefScoreArray,
218                        score_error_t *scoreError,
219                        uint32_t numRows,
220                        uint32_t numColumns) {
221     uint32_t numErrors = 0;
222
223     ClearScoreError(scoreError);
224
225     float *A = ptrScoreArray;
226     float *B = reinterpret_cast<float *>(ptrRefScoreArray);
227     for (uint32_t i = 0; i < numRows; i++) {
228         for (uint32_t j = 0; j < numColumns; j++) {
229             float score = A[i * numColumns + j];
230             float refscore = B[i * numColumns + j];
231             float error = fabs(refscore - score);
232             float rel_error = error / (static_cast<float>(fabs(refscore)) + 1e-20f);
233             float squared_error = error * error;
234             float squared_rel_error = rel_error * rel_error;
235             scoreError->numScores++;
236             scoreError->sumError += error;
237             scoreError->sumSquaredError += squared_error;
238             if (error > scoreError->maxError) {
239                 scoreError->maxError = error;
240             }
241             scoreError->sumRelError += rel_error;
242             scoreError->sumSquaredRelError += squared_rel_error;
243             if (rel_error > scoreError->maxRelError) {
244                 scoreError->maxRelError = rel_error;
245             }
246             if (error > scoreError->threshold) {
247                 numErrors++;
248             }
249         }
250     }
251     scoreError->rmsError = sqrt(scoreError->sumSquaredError / (numRows * numColumns));
252     scoreError->sumRmsError += scoreError->rmsError;
253     scoreError->numErrors = numErrors;
254
255     return (numErrors);
256 }
257
258 float StdDevError(score_error_t error) {
259     return (sqrt(error.sumSquaredError / error.numScores
260                  - (error.sumError / error.numScores) * (error.sumError / error.numScores)));
261 }
262
263 float StdDevRelError(score_error_t error) {
264     return (sqrt(error.sumSquaredRelError / error.numScores
265                  - (error.sumRelError / error.numScores) * (error.sumRelError / error.numScores)));
266 }
267
268 #if !defined(__arm__) && !defined(_M_ARM)
269 #if defined(_WIN32) || defined(WIN32)
270 #include <intrin.h>
271 #include <windows.h>
272 #else
273
274 #include <cpuid.h>
275
276 #endif
277
278 inline void native_cpuid(unsigned int *eax, unsigned int *ebx,
279                          unsigned int *ecx, unsigned int *edx) {
280     size_t level = *eax;
281 #if defined(_WIN32) || defined(WIN32)
282     int regs[4] = {static_cast<int>(*eax), static_cast<int>(*ebx), static_cast<int>(*ecx), static_cast<int>(*edx)};
283     __cpuid(regs, level);
284     *eax = static_cast<uint32_t>(regs[0]);
285     *ebx = static_cast<uint32_t>(regs[1]);
286     *ecx = static_cast<uint32_t>(regs[2]);
287     *edx = static_cast<uint32_t>(regs[3]);
288 #else
289     __get_cpuid(level, eax, ebx, ecx, edx);
290 #endif
291 }
292
293 // return GNA module frequency in MHz
294 float getGnaFrequencyMHz() {
295     uint32_t eax = 1;
296     uint32_t ebx = 0;
297     uint32_t ecx = 0;
298     uint32_t edx = 0;
299     uint32_t family = 0;
300     uint32_t model = 0;
301     const uint8_t sixth_family = 6;
302     const uint8_t cannon_lake_model = 102;
303     const uint8_t gemini_lake_model = 122;
304
305     native_cpuid(&eax, &ebx, &ecx, &edx);
306     family = (eax >> 8) & 0xF;
307
308     // model is the concatenation of two fields
309     // | extended model | model |
310     // copy extended model data
311     model = (eax >> 16) & 0xF;
312     // shift
313     model <<= 4;
314     // copy model data
315     model += (eax >> 4) & 0xF;
316
317     if (family == sixth_family && model == cannon_lake_model) {
318         return 400;
319     } else if (family == sixth_family &&
320                model == gemini_lake_model) {
321         return 200;
322     } else {
323         // counters not supported and we retrns just default value
324         return 1;
325     }
326 }
327
328 #endif  // !defined(__arm__) && !defined(_M_ARM)
329
330 void printReferenceCompareResults(score_error_t const &totalError,
331                                   size_t framesNum,
332                                   std::ostream &stream) {
333     stream << "         max error: " <<
334            totalError.maxError << std::endl;
335     stream << "         avg error: " <<
336            totalError.sumError / totalError.numScores << std::endl;
337     stream << "     avg rms error: " <<
338            totalError.sumRmsError / framesNum << std::endl;
339     stream << "       stdev error: " <<
340            StdDevError(totalError) << std::endl << std::endl;
341     stream << std::endl;
342 }
343
344 void printPerformanceCounters(std::map<std::string,
345         InferenceEngine::InferenceEngineProfileInfo> const &utterancePerfMap,
346                               size_t callsNum,
347                               std::ostream &stream) {
348 #if !defined(__arm__) && !defined(_M_ARM)
349     stream << std::endl << "Performance counts:" << std::endl;
350     stream << std::setw(10) << std::right << "" << "Counter descriptions";
351     stream << std::setw(22) << "Utt scoring time";
352     stream << std::setw(18) << "Avg infer time";
353     stream << std::endl;
354
355     stream << std::setw(46) << "(ms)";
356     stream << std::setw(24) << "(us per call)";
357     stream << std::endl;
358
359     for (const auto &it : utterancePerfMap) {
360         std::string const &counter_name = it.first;
361         float current_units = static_cast<float>(it.second.realTime_uSec);
362         float call_units = current_units / callsNum;
363         // if GNA HW counters
364         // get frequency of GNA module
365         float freq = getGnaFrequencyMHz();
366         current_units /= freq * 1000;
367         call_units /= freq;
368         stream << std::setw(30) << std::left << counter_name.substr(4, counter_name.size() - 1);
369         stream << std::setw(16) << std::right << current_units;
370         stream << std::setw(21) << std::right << call_units;
371         stream << std::endl;
372     }
373     stream << std::endl;
374 #endif
375 }
376
377 void getPerformanceCounters(InferenceEngine::InferRequest &request,
378                             std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &perfCounters) {
379     auto retPerfCounters = request.GetPerformanceCounts();
380
381     for (const auto &pair : retPerfCounters) {
382         perfCounters[pair.first] = pair.second;
383     }
384 }
385
386 void sumPerformanceCounters(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> const &perfCounters,
387                             std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &totalPerfCounters) {
388     for (const auto &pair : perfCounters) {
389         totalPerfCounters[pair.first].realTime_uSec += pair.second.realTime_uSec;
390     }
391 }
392
393 bool ParseAndCheckCommandLine(int argc, char *argv[]) {
394     // ---------------------------Parsing and validation of input args--------------------------------------
395     slog::info << "Parsing input parameters" << slog::endl;
396
397     gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
398     if (FLAGS_h) {
399         showUsage();
400         return false;
401     }
402     bool isDumpMode = !FLAGS_wg.empty() || !FLAGS_we.empty();
403
404     // input not required only in dump mode and if external scale factor provided
405     if (FLAGS_i.empty() && (!isDumpMode || FLAGS_q.compare("user") != 0)) {
406         if (isDumpMode) {
407             throw std::logic_error("In model dump mode either static quantization is used (-i) or user scale"
408                                    " factor need to be provided. See -q user option");
409         }
410         throw std::logic_error("Input file not set. Please use -i.");
411     }
412
413     if (FLAGS_m.empty() && FLAGS_rg.empty()) {
414         throw std::logic_error("Either IR file (-m) or GNAModel file (-rg) need to be set.");
415     }
416
417     if ((!FLAGS_m.empty() && !FLAGS_rg.empty())) {
418         throw std::logic_error("Only one of -m and -rg is allowed.");
419     }
420
421     std::vector<std::string> possibleDeviceTypes = {
422             "CPU",
423             "GPU",
424             "GNA_AUTO",
425             "GNA_HW",
426             "GNA_SW_EXACT",
427             "GNA_SW",
428             "HETERO:GNA,CPU",
429             "HETERO:GNA_HW,CPU",
430             "HETERO:GNA_SW_EXACT,CPU",
431             "HETERO:GNA_SW,CPU",
432     };
433
434     if (std::find(possibleDeviceTypes.begin(), possibleDeviceTypes.end(), FLAGS_d) == possibleDeviceTypes.end()) {
435         throw std::logic_error("Specified device is not supported.");
436     }
437
438     float scaleFactorInput = static_cast<float>(FLAGS_sf);
439     if (scaleFactorInput <= 0.0f) {
440         throw std::logic_error("Scale factor out of range (must be non-negative).");
441     }
442
443     uint32_t batchSize = (uint32_t) FLAGS_bs;
444     if ((batchSize < 1) || (batchSize > 8)) {
445         throw std::logic_error("Batch size out of range (1..8).");
446     }
447
448     /** default is a static quantisation **/
449     if ((FLAGS_q.compare("static") != 0) && (FLAGS_q.compare("dynamic") != 0) && (FLAGS_q.compare("user") != 0)) {
450         throw std::logic_error("Quantization mode not supported (static, dynamic, user).");
451     }
452
453     if (FLAGS_q.compare("dynamic") == 0) {
454         throw std::logic_error("Dynamic quantization not yet supported.");
455     }
456
457     if (FLAGS_qb != 16 && FLAGS_qb != 8) {
458         throw std::logic_error("Only 8 or 16 bits supported.");
459     }
460
461     if (FLAGS_nthreads <= 0) {
462         throw std::logic_error("Not valid value for 'nthreads' argument. It should be > 0 ");
463     }
464
465     if (FLAGS_cw < 0) {
466         throw std::logic_error("Not valid value for 'cw' argument. It should be > 0 ");
467     }
468
469     return true;
470 }
471
472 /**
473  * @brief The entry point for inference engine automatic speech recognition sample
474  * @file speech_sample/main.cpp
475  * @example speech_sample/main.cpp
476  */
477 int main(int argc, char *argv[]) {
478     try {
479         slog::info << "InferenceEngine: " << GetInferenceEngineVersion() << slog::endl;
480
481         // ------------------------------ Parsing and validation of input args ---------------------------------
482         if (!ParseAndCheckCommandLine(argc, argv)) {
483             return 0;
484         }
485
486         if (FLAGS_l.empty()) {
487             slog::info << "No extensions provided" << slog::endl;
488         }
489
490         auto isFeature = [&](const std::string xFeature) { return FLAGS_d.find(xFeature) != std::string::npos; };
491
492         bool useGna = isFeature("GNA");
493         bool useHetero = isFeature("HETERO");
494         std::string deviceStr =
495                 useHetero && useGna ? "HETERO:GNA,CPU" : FLAGS_d.substr(0, (FLAGS_d.find("_")));
496         float scaleFactorInput = static_cast<float>(FLAGS_sf);
497         uint32_t batchSize = FLAGS_cw > 0 ? 1 : (uint32_t) FLAGS_bs;
498         /** Extract input ark file name **/
499         std::string inputArkName = fileNameNoExt(FLAGS_i) + ".ark";
500
501         uint32_t numUtterances(0), numBytesThisUtterance(0);
502         if (!FLAGS_i.empty()) {
503             GetKaldiArkInfo(inputArkName.c_str(), 0, &numUtterances, &numBytesThisUtterance);
504         }
505         // -----------------------------------------------------------------------------------------------------
506
507         // --------------------------- 1. Load Plugin for inference engine -------------------------------------
508         slog::info << "Loading plugin" << slog::endl;
509         /** Loading plugin for device **/
510         InferencePlugin plugin = PluginDispatcher({FLAGS_pp}).getPluginByDevice(deviceStr);
511
512         /** Printing plugin version **/
513         std::cout << plugin.GetVersion() << std::endl << std::endl;
514         // -----------------------------------------------------------------------------------------------------
515
516         // --------------------------- 2. Read IR Generated by ModelOptimizer (.xml and .bin files) ------------
517         slog::info << "Loading network files" << slog::endl;
518
519         CNNNetReader netBuilder;
520         if (!FLAGS_m.empty()) {
521             /** Read network model **/
522             netBuilder.ReadNetwork(FLAGS_m);
523
524             /** Extract model name and load weights **/
525             std::string binFileName = fileNameNoExt(FLAGS_m) + ".bin";
526             netBuilder.ReadWeights(binFileName);
527
528             // -------------------------------------------------------------------------------------------------
529
530             // --------------------------- 3. Set batch size ---------------------------------------------------
531             /** Set batch size.  Unlike in imaging, batching in time (rather than space) is done for speech recognition. **/
532             netBuilder.getNetwork().setBatchSize(batchSize);
533             slog::info << "Batch size is " << std::to_string(netBuilder.getNetwork().getBatchSize())
534                        << slog::endl;
535         }
536
537         /** Setting plugin parameter for per layer metrics **/
538         std::map<std::string, std::string> gnaPluginConfig;
539         std::map<std::string, std::string> genericPluginConfig;
540         if (useGna) {
541             std::string gnaDevice =
542                     useHetero ? FLAGS_d.substr(FLAGS_d.find("GNA"), FLAGS_d.find(",") - FLAGS_d.find("GNA")) : FLAGS_d;
543             gnaPluginConfig[GNAConfigParams::KEY_GNA_DEVICE_MODE] =
544                     gnaDevice.find("_") == std::string::npos ? "GNA_AUTO" : gnaDevice;
545         } else if (plugin.GetVersion()->description == std::string("MKLDNNPlugin")) {
546             /**
547              * cpu_extensions library is compiled from "extension" folder containing
548              * custom MKLDNNPlugin layer implementations. These layers are not supported
549              * by mkldnn, but they can be useful for inferring custom topologies.
550             **/
551             plugin.AddExtension(std::make_shared<Extensions::Cpu::CpuExtensions>());
552         }
553
554         if (FLAGS_pc) {
555             genericPluginConfig[PluginConfigParams::KEY_PERF_COUNT] = PluginConfigParams::YES;
556         }
557
558         if (FLAGS_q.compare("user") == 0) {
559             std::cout << "[ INFO ] Using scale factor of " << FLAGS_sf << std::endl;
560             gnaPluginConfig[GNA_CONFIG_KEY(SCALE_FACTOR)] = std::to_string(FLAGS_sf);
561         } else {  // "static" quantization with calculated scale factor
562             std::string name;
563             std::vector<uint8_t> ptrFeatures;
564             uint32_t numArrays(0), numBytes(0), numFrames(0), numFrameElements(0), numBytesPerElement(0);
565             GetKaldiArkInfo(inputArkName.c_str(), 0, &numArrays, &numBytes);
566             ptrFeatures.resize(numBytes);
567             LoadKaldiArkArray(inputArkName.c_str(),
568                               0,
569                               name,
570                               ptrFeatures,
571                               &numFrames,
572                               &numFrameElements,
573                               &numBytesPerElement);
574             scaleFactorInput =
575                     ScaleFactorForQuantization(ptrFeatures.data(), MAX_VAL_2B_FEAT, numFrames * numFrameElements);
576             slog::info << "Using scale factor of " << scaleFactorInput << " calculated from first utterance."
577                        << slog::endl;
578             gnaPluginConfig[GNA_CONFIG_KEY(SCALE_FACTOR)] = std::to_string(scaleFactorInput);
579         }
580
581         if (FLAGS_qb == 8) {
582             gnaPluginConfig[GNAConfigParams::KEY_GNA_PRECISION] = "I8";
583         } else {
584             gnaPluginConfig[GNAConfigParams::KEY_GNA_PRECISION] = "I16";
585         }
586
587         gnaPluginConfig[GNAConfigParams::KEY_GNA_LIB_N_THREADS] = std::to_string(FLAGS_cw > 0 ? 1 : FLAGS_nthreads);
588         gnaPluginConfig[GNA_CONFIG_KEY(COMPACT_MODE)] = CONFIG_VALUE(NO);
589         // -----------------------------------------------------------------------------------------------------
590
591         // --------------------------- 4. Write model to file --------------------------------------------------
592         // Embedded GNA model dumping (for Intel(R) Speech Enabling Developer Kit)
593         if (!FLAGS_we.empty()) {
594             gnaPluginConfig[GNAConfigParams::KEY_GNA_FIRMWARE_MODEL_IMAGE] = FLAGS_we;
595         }
596         // -----------------------------------------------------------------------------------------------------
597
598         // --------------------------- 5. Loading model to the plugin ------------------------------------------
599
600         if (useGna) {
601             genericPluginConfig.insert(std::begin(gnaPluginConfig), std::end(gnaPluginConfig));
602         }
603         auto t0 = Time::now();
604         ExecutableNetwork executableNet;
605
606         if (!FLAGS_m.empty()) {
607             slog::info << "Loading model to the plugin" << slog::endl;
608             executableNet = plugin.LoadNetwork(netBuilder.getNetwork(), genericPluginConfig);
609         } else {
610             slog::info << "Importing model to the plugin" << slog::endl;
611             executableNet = plugin.ImportNetwork(FLAGS_rg.c_str(), genericPluginConfig);
612         }
613
614         ms loadTime = std::chrono::duration_cast<ms>(Time::now() - t0);
615         slog::info << "Model loading time " << loadTime.count() << " ms" << slog::endl;
616
617         // --------------------------- 6. Exporting gna model using InferenceEngine AOT API---------------------
618         if (!FLAGS_wg.empty()) {
619             slog::info << "Writing GNA Model to file " << FLAGS_wg << slog::endl;
620             t0 = Time::now();
621             executableNet.Export(FLAGS_wg);
622             ms exportTime = std::chrono::duration_cast<ms>(Time::now() - t0);
623             slog::info << "Exporting time " << exportTime.count() << " ms" << slog::endl;
624             return 0;
625         }
626
627         if (!FLAGS_we.empty()) {
628             slog::info << "Exported GNA embedded model to file " << FLAGS_we << slog::endl;
629             return 0;
630         }
631
632         std::vector<InferRequestStruct> inferRequests(FLAGS_cw > 0 ? 1 : FLAGS_nthreads);
633         for (auto& inferRequest : inferRequests) {
634             inferRequest = {executableNet.CreateInferRequest(), -1, batchSize};
635         }
636         // -----------------------------------------------------------------------------------------------------
637
638         // --------------------------- 7. Prepare input blobs --------------------------------------------------
639         /** Taking information about all topology inputs **/
640         ConstInputsDataMap cInputInfo = executableNet.GetInputsInfo();
641         InputsDataMap inputInfo;
642         if (!FLAGS_m.empty()) {
643             inputInfo = netBuilder.getNetwork().getInputsInfo();
644         }
645
646         /** Stores all input blobs data **/
647         if (cInputInfo.size() != 1) {
648             throw std::logic_error("Sample supports only topologies with  1 input");
649         }
650
651         Blob::Ptr ptrInputBlob = inferRequests[0].inferRequest.GetBlob(cInputInfo.begin()->first);
652
653         /** configure input precision if model loaded from IR **/
654         for (auto &item : inputInfo) {
655             Precision inputPrecision = Precision::FP32;  // specify Precision::I16 to provide quantized inputs
656             item.second->setPrecision(inputPrecision);
657             item.second->getInputData()->layout = NC;  // row major layout
658         }
659
660         // -----------------------------------------------------------------------------------------------------
661
662         // --------------------------- 8. Prepare output blobs -------------------------------------------------
663         ConstOutputsDataMap cOutputInfo(executableNet.GetOutputsInfo());
664         OutputsDataMap outputInfo;
665         if (!FLAGS_m.empty()) {
666             outputInfo = netBuilder.getNetwork().getOutputsInfo();
667         }
668
669         Blob::Ptr ptrOutputBlob = inferRequests[0].inferRequest.GetBlob(cOutputInfo.begin()->first);
670
671         for (auto &item : outputInfo) {
672             DataPtr outData = item.second;
673             if (!outData) {
674                 throw std::logic_error("output data pointer is not valid");
675             }
676
677             Precision outputPrecision = Precision::FP32;  // specify Precision::I32 to retrieve quantized outputs
678             outData->setPrecision(outputPrecision);
679             outData->layout = NC;  // row major layout
680         }
681         // -----------------------------------------------------------------------------------------------------
682
683         // --------------------------- 9. Do inference ---------------------------------------------------------
684         std::vector<uint8_t> ptrUtterance;
685         std::vector<uint8_t> ptrScores;
686         std::vector<uint8_t> ptrReferenceScores;
687         score_error_t frameError, totalError;
688
689         for (uint32_t utteranceIndex = 0; utteranceIndex < numUtterances; ++utteranceIndex) {
690             std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> utterancePerfMap;
691             std::string uttName;
692             uint32_t numFrames(0), numFrameElementsInput(0), numBytesPerElementInput(0), n(0);
693             uint32_t numFramesReference(0), numFrameElementsReference(0), numBytesPerElementReference(0),
694                     numBytesReferenceScoreThisUtterance(0);
695             const uint32_t numScoresPerFrame = ptrOutputBlob->size() / batchSize;
696             GetKaldiArkInfo(inputArkName.c_str(), utteranceIndex, &n, &numBytesThisUtterance);
697             ptrUtterance.resize(numBytesThisUtterance);
698             LoadKaldiArkArray(inputArkName.c_str(),
699                               utteranceIndex,
700                               uttName,
701                               ptrUtterance,
702                               &numFrames,
703                               &numFrameElementsInput,
704                               &numBytesPerElementInput);
705
706             uint32_t numFrameElementsInputPadded = numFrameElementsInput;
707
708             if (ptrInputBlob->size() != numFrameElementsInputPadded * batchSize) {
709                 throw std::logic_error("network input size(" + std::to_string(ptrInputBlob->size()) +
710                                        ") mismatch to ark file size (" +
711                                        std::to_string(numFrameElementsInputPadded * batchSize) + ")");
712             }
713             ptrScores.resize(numFrames * numScoresPerFrame * sizeof(float));
714             if (!FLAGS_r.empty()) {
715                 std::string refUtteranceName;
716                 GetKaldiArkInfo(FLAGS_r.c_str(), utteranceIndex, &n, &numBytesReferenceScoreThisUtterance);
717                 ptrReferenceScores.resize(numBytesReferenceScoreThisUtterance);
718                 LoadKaldiArkArray(FLAGS_r.c_str(),
719                                   utteranceIndex,
720                                   refUtteranceName,
721                                   ptrReferenceScores,
722                                   &numFramesReference,
723                                   &numFrameElementsReference,
724                                   &numBytesPerElementReference);
725             }
726
727             double totalTime = 0.0;
728
729             std::cout << "Utterance " << utteranceIndex << ": " << std::endl;
730
731             ClearScoreError(&totalError);
732             totalError.threshold = frameError.threshold = MAX_SCORE_DIFFERENCE;
733             auto inputFrame = &ptrUtterance.front();
734             auto outputFrame = &ptrScores.front();
735
736             std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> callPerfMap;
737
738             size_t frameIndex = 0;
739             numFrames += 2 * FLAGS_cw;
740             uint32_t numFramesThisBatch{batchSize};
741
742             auto t0 = Time::now();
743             auto t1 = t0;
744
745             while (frameIndex <= numFrames) {
746                 if (frameIndex == numFrames) {
747                     if (std::find_if(inferRequests.begin(),
748                             inferRequests.end(),
749                             [&](InferRequestStruct x) { return (x.frameIndex != -1); } ) == inferRequests.end()) {
750                         break;
751                     }
752                 }
753
754                 bool inferRequestFetched = false;
755                 for (auto &inferRequest : inferRequests) {
756                     if (frameIndex == numFrames) {
757                         numFramesThisBatch = 1;
758                     } else {
759                         numFramesThisBatch = (numFrames - frameIndex < batchSize) ? (numFrames - frameIndex)
760                                                                                   : batchSize;
761                     }
762
763                     if (inferRequest.frameIndex != -1) {
764                         StatusCode code = inferRequest.inferRequest.Wait(
765                                 InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
766                         if (code != StatusCode::OK) {
767                             if (!useHetero) continue;
768                             if (code != StatusCode::INFER_NOT_STARTED) continue;
769                         }
770
771                         if (inferRequest.frameIndex >= 0) {
772                             if (!FLAGS_o.empty()) {
773                                 outputFrame =
774                                         &ptrScores.front() + numScoresPerFrame * sizeof(float) * (inferRequest.frameIndex);
775                                 Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(cOutputInfo.begin()->first);
776                                 auto byteSize = inferRequest.numFramesThisBatch * numScoresPerFrame * sizeof(float);
777                                 std::memcpy(outputFrame,
778                                             outputBlob->buffer(),
779                                             byteSize);
780                             }
781
782                             if (!FLAGS_r.empty()) {
783                                 Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(cOutputInfo.begin()->first);
784                                 CompareScores(outputBlob->buffer().as<float *>(),
785                                               &ptrReferenceScores[inferRequest.frameIndex *
786                                                                   numFrameElementsReference *
787                                                                   numBytesPerElementReference],
788                                               &frameError,
789                                               inferRequest.numFramesThisBatch,
790                                               numFrameElementsReference);
791                                 UpdateScoreError(&frameError, &totalError);
792                             }
793                             if (FLAGS_pc) {
794                                 // retrive new counters
795                                 getPerformanceCounters(inferRequest.inferRequest, callPerfMap);
796                                 // summarize retrived counters with all previous
797                                 sumPerformanceCounters(callPerfMap, utterancePerfMap);
798                             }
799                         }
800                     }
801
802                     if (frameIndex == numFrames) {
803                         inferRequest.frameIndex = -1;
804                         continue;
805                     }
806
807                     Blob::Ptr inputBlob = inferRequest.inferRequest.GetBlob(cInputInfo.begin()->first);
808
809                     std::memcpy(inputBlob->buffer(),
810                                 inputFrame,
811                                 inputBlob->byteSize());
812
813                     auto index = frameIndex - 2 * FLAGS_cw;
814                     inferRequest.inferRequest.StartAsync();
815                     inferRequest.frameIndex = index < 0 ? -2 : index;
816                     inferRequest.numFramesThisBatch = numFramesThisBatch;
817
818                     frameIndex += numFramesThisBatch;
819
820                     if (FLAGS_cw > 0) {
821                         int i = frameIndex - FLAGS_cw;
822                         if (i > 0 && i < static_cast<int>(numFrames)) {
823                             inputFrame += sizeof(float) * numFrameElementsInput * numFramesThisBatch;
824                         } else if (i >= static_cast<int>(numFrames)) {
825                             inputFrame = &ptrUtterance.front() +
826                                          (numFrames - 1) * sizeof(float) * numFrameElementsInput *
827                                          numFramesThisBatch;
828                         }
829                     } else {
830                         inputFrame += sizeof(float) * numFrameElementsInput * numFramesThisBatch;
831                     }
832                     inferRequestFetched |= true;
833                 }
834
835                 if (!inferRequestFetched) {
836                     std::this_thread::sleep_for(std::chrono::milliseconds(1));
837                     continue;
838                 }
839             }
840             t1 = Time::now();
841
842             fsec fs = t1 - t0;
843             ms d = std::chrono::duration_cast<ms>(fs);
844             totalTime += d.count();
845
846             // resetting state between utterances
847             for (auto &&state : executableNet.QueryState()) {
848                 state.Reset();
849             }
850
851             if (!FLAGS_o.empty()) {
852                 bool shouldAppend = (utteranceIndex == 0) ? false : true;
853                 SaveKaldiArkArray(FLAGS_o.c_str(), shouldAppend, uttName, &ptrScores.front(),
854                                   numFrames, numScoresPerFrame);
855             }
856
857             /** Show performance results **/
858             std::cout << "Total time in Infer (HW and SW):\t" << totalTime << " ms"
859                       << std::endl;
860             std::cout << "Frames in utterance:\t\t\t" << numFrames << " frames"
861                       << std::endl;
862             std::cout << "Average Infer time per frame:\t\t" << totalTime / static_cast<double>(numFrames) << " ms"
863                       << std::endl;
864             if (FLAGS_pc) {
865                 // print
866                 printPerformanceCounters(utterancePerfMap, frameIndex, std::cout);
867             }
868             if (!FLAGS_r.empty()) {
869                 printReferenceCompareResults(totalError, numFrames, std::cout);
870             }
871             std::cout << "End of Utterance " << utteranceIndex << std::endl << std::endl;
872         }
873         // -----------------------------------------------------------------------------------------------------
874     }
875     catch (const std::exception &error) {
876         slog::err << error.what() << slog::endl;
877         return 1;
878     }
879     catch (...) {
880         slog::err << "Unknown/internal exception happened" << slog::endl;
881         return 1;
882     }
883
884     slog::info << "Execution successful" << slog::endl;
885     return 0;
886 }