1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
#include "speech_sample.hpp"
#include "rockhopper_decoder.h"

#include <gflags/gflags.h>

#include <memory>
#include <new>

#include <inference_engine.hpp>
#include <gna/gna_config.hpp>
#include <samples/common.hpp>
#include <samples/slog.hpp>
#include <samples/args_helper.hpp>
31 #define ALIGN(memSize, pad) ((static_cast<int>((memSize) + pad - 1) / pad) * pad)
33 #define MAX_SCORE_DIFFERENCE 0.0001f
34 #define MAX_VAL_2B_FEAT 16384
36 using namespace InferenceEngine;
38 typedef std::chrono::high_resolution_clock Time;
39 typedef std::chrono::duration<double, std::ratio<1, 1000>> ms;
40 typedef std::chrono::duration<float> fsec;
// Visible members of the score_error_t accumulator struct; its declaration
// header and remaining fields (numScores, numErrors, threshold, maxError,
// rmsError, sum* — used throughout this file) are elided from this view.
// Running sum of squared absolute errors (feeds RMS / std-dev computation).
49 float sumSquaredError;
// Running sum of squared relative errors (feeds relative std-dev computation).
52 float sumSquaredRelError;
// Bundles an InferenceEngine request with the per-batch bookkeeping the
// frame-scattering inference loop in main() needs.
// NOTE(review): the frameIndex member (initialized to -1 near the end of this
// chunk and tested via x.frameIndex) and the closing brace are elided here.
55 struct InferRequestStruct {
56 InferRequest inferRequest;
58 uint32_t numFramesThisBatch;
// Owns the Rockhopper decoder handle plus the raw resource buffers
// (hmm_data, cl_data, g_data, label_data — member declarations elided from
// this view) allocated by ReadBinaryFile and released by FreeRhDecoder.
61 struct RhDecoderInstanceParams {
62 RhDecoderInstanceHandle handle;
// Scan a Kaldi ARK file: count the arrays (utterances) it contains and report
// the byte size of the array at index numArrayToFindSize.
// An ARK entry is: "<name>\0", then "BFM \4" marker, uint32 rows, '\4',
// uint32 cols, then rows*cols float32 payload.
// NOTE(review): several lines (the `line` declaration, loop closing, error
// branch wiring, numArrays increment) are elided from this view; the comments
// below describe only the visible code.
69 void GetKaldiArkInfo(const char *fileName,
70 uint32_t numArrayToFindSize,
71 uint32_t *ptrNumArrays,
72 uint32_t *ptrNumMemoryBytes) {
73 uint32_t numArrays = 0;
74 uint32_t numMemoryBytes = 0;
76 std::ifstream in_file(fileName, std::ios::binary);
78 while (!in_file.eof()) {
80 uint32_t numRows = 0u, numCols = 0u, num_bytes = 0u;
81 std::getline(in_file, line, '\0'); // read variable length name followed by space and NUL
82 std::getline(in_file, line, '\4'); // read "BFM" followed by space and control-D
// Anything other than the "BFM " float-matrix marker means a malformed file.
83 if (line.compare("BFM ") != 0) {
86 in_file.read(reinterpret_cast<char *>(&numRows), sizeof(uint32_t)); // read number of rows
87 std::getline(in_file, line, '\4'); // read control-D
88 in_file.read(reinterpret_cast<char *>(&numCols), sizeof(uint32_t)); // read number of columns
89 num_bytes = numRows * numCols * sizeof(float);
// Skip the payload — only its size matters here.
90 in_file.seekg(num_bytes, in_file.cur); // read data
// Accumulate the payload size only for the requested array index.
92 if (numArrays == numArrayToFindSize) {
93 numMemoryBytes += num_bytes;
99 fprintf(stderr, "Failed to open %s for reading in GetKaldiArkInfo()!\n", fileName);
// Out-parameters are optional: NULL means the caller does not need that value.
103 if (ptrNumArrays != NULL) *ptrNumArrays = numArrays;
104 if (ptrNumMemoryBytes != NULL) *ptrNumMemoryBytes = numMemoryBytes;
// Load the arrayIndex-th array from a Kaldi ARK file: skip the preceding
// arrays entry-by-entry, then read the requested array's name, dimensions and
// float32 payload. `memory` must already be sized by the caller (typically
// via GetKaldiArkInfo) to hold rows*cols floats.
// NOTE(review): the `i`/`line` declarations, loop/branch closings and the
// index increment are elided from this view.
107 void LoadKaldiArkArray(const char *fileName, uint32_t arrayIndex, std::string &ptrName, std::vector<uint8_t> &memory,
108 uint32_t *ptrNumRows, uint32_t *ptrNumColumns, uint32_t *ptrNumBytesPerElement) {
109 std::ifstream in_file(fileName, std::ios::binary);
110 if (in_file.good()) {
// Skip arrays 0..arrayIndex-1 without keeping their data.
112 while (i < arrayIndex) {
114 uint32_t numRows = 0u, numCols = 0u;
115 std::getline(in_file, line, '\0'); // read variable length name followed by space and NUL
116 std::getline(in_file, line, '\4'); // read "BFM" followed by space and control-D
117 if (line.compare("BFM ") != 0) {
120 in_file.read(reinterpret_cast<char *>(&numRows), sizeof(uint32_t)); // read number of rows
121 std::getline(in_file, line, '\4'); // read control-D
122 in_file.read(reinterpret_cast<char *>(&numCols), sizeof(uint32_t)); // read number of columns
123 in_file.seekg(numRows * numCols * sizeof(float), in_file.cur); // read data
// Now positioned at the requested array: read its header and payload.
126 if (!in_file.eof()) {
128 std::getline(in_file, ptrName, '\0'); // read variable length name followed by space and NUL
129 std::getline(in_file, line, '\4'); // read "BFM" followed by space and control-D
130 if (line.compare("BFM ") != 0) {
131 fprintf(stderr, "Cannot find array specifier in file %s in LoadKaldiArkArray()!\n", fileName);
134 in_file.read(reinterpret_cast<char *>(ptrNumRows), sizeof(uint32_t)); // read number of rows
135 std::getline(in_file, line, '\4'); // read control-D
136 in_file.read(reinterpret_cast<char *>(ptrNumColumns), sizeof(uint32_t)); // read number of columns
137 in_file.read(reinterpret_cast<char *>(&memory.front()),
138 *ptrNumRows * *ptrNumColumns * sizeof(float)); // read array data
// NOTE(review): error message names GetKaldiArkInfo() but this is
// LoadKaldiArkArray() — looks like a copy/paste slip; confirm before fixing.
142 fprintf(stderr, "Failed to open %s for reading in GetKaldiArkInfo()!\n", fileName);
// ARK "BFM" payloads are always float32.
146 *ptrNumBytesPerElement = sizeof(float);
// Write one named float matrix to an ARK file in the binary "BFM " format
// (mirrors the layout parsed by GetKaldiArkInfo/LoadKaldiArkArray).
// NOTE(review): most of the parameter list (name, ptrMemory, the
// append-mode flag, numRows) is elided from this signature view, as is the
// condition guarding the append-mode line.
149 void SaveKaldiArkArray(const char *fileName,
154 uint32_t numColumns) {
155 std::ios_base::openmode mode = std::ios::binary;
// Append mode lets successive utterances accumulate in one output ARK file.
157 mode |= std::ios::app;
159 std::ofstream out_file(fileName, mode);
160 if (out_file.good()) {
161 out_file.write(name.c_str(), name.length()); // write name
162 out_file.write("\0", 1);
163 out_file.write("BFM ", 4);
164 out_file.write("\4", 1);
165 out_file.write(reinterpret_cast<char *>(&numRows), sizeof(uint32_t));
166 out_file.write("\4", 1);
167 out_file.write(reinterpret_cast<char *>(&numColumns), sizeof(uint32_t));
168 out_file.write(reinterpret_cast<char *>(ptrMemory), numRows * numColumns * sizeof(float));
171 throw std::runtime_error(std::string("Failed to open %s for writing in SaveKaldiArkArray()!\n") + fileName);
/**
 * @brief Compute the scale factor that maps the largest-magnitude value in a
 *        float buffer onto a target maximum (used for GNA input quantization).
 * @param ptrFloatMemory pointer to the float data (callers pass it as void*)
 * @param targetMax      desired maximum absolute value after scaling
 * @param numElements    number of float elements in the buffer
 * @return targetMax / max(|x|), or 1.0f when the buffer is empty or all-zero
 *         (avoids a division by zero and leaves the data unscaled)
 */
float ScaleFactorForQuantization(void *ptrFloatMemory, float targetMax, uint32_t numElements) {
    const float *ptrFloatFeat = reinterpret_cast<const float *>(ptrFloatMemory);
    float max = 0.0f;

    // Find the largest absolute value in the buffer.
    for (uint32_t i = 0; i < numElements; i++) {
        const float magnitude = fabs(ptrFloatFeat[i]);
        if (magnitude > max) {
            max = magnitude;
        }
    }

    // An all-zero (or empty) buffer has no dynamic range: fall back to a
    // neutral scale factor instead of dividing by zero.
    if (max == 0.0f) {
        return 1.0f;
    }
    return targetMax / max;
}
195 void ClearScoreError(score_error_t *error) {
196 error->numScores = 0;
197 error->numErrors = 0;
198 error->maxError = 0.0;
199 error->rmsError = 0.0;
200 error->sumError = 0.0;
201 error->sumRmsError = 0.0;
202 error->sumSquaredError = 0.0;
203 error->maxRelError = 0.0;
204 error->sumRelError = 0.0;
205 error->sumSquaredRelError = 0.0;
208 void UpdateScoreError(score_error_t *error, score_error_t *totalError) {
209 totalError->numErrors += error->numErrors;
210 totalError->numScores += error->numScores;
211 totalError->sumRmsError += error->rmsError;
212 totalError->sumError += error->sumError;
213 totalError->sumSquaredError += error->sumSquaredError;
214 if (error->maxError > totalError->maxError) {
215 totalError->maxError = error->maxError;
217 totalError->sumRelError += error->sumRelError;
218 totalError->sumSquaredRelError += error->sumSquaredRelError;
219 if (error->maxRelError > totalError->maxRelError) {
220 totalError->maxRelError = error->maxRelError;
// Element-wise comparison of an inferred score matrix against reference
// scores. Fills scoreError with absolute/relative error statistics and
// returns the count of elements whose absolute error exceeded
// scoreError->threshold.
// NOTE(review): the numRows parameter line, the body of the
// threshold-exceeded branch (presumably numErrors++) and the return
// statement are elided from this view.
224 uint32_t CompareScores(float *ptrScoreArray,
225 void *ptrRefScoreArray,
226 score_error_t *scoreError,
228 uint32_t numColumns) {
229 uint32_t numErrors = 0;
// Start from a clean accumulator; threshold is preserved by ClearScoreError.
231 ClearScoreError(scoreError);
233 float *A = ptrScoreArray;
234 float *B = reinterpret_cast<float *>(ptrRefScoreArray);
235 for (uint32_t i = 0; i < numRows; i++) {
236 for (uint32_t j = 0; j < numColumns; j++) {
237 float score = A[i * numColumns + j];
238 float refscore = B[i * numColumns + j];
239 float error = fabs(refscore - score);
// Epsilon in the denominator guards against division by a zero reference.
240 float rel_error = error / (static_cast<float>(fabs(refscore)) + 1e-20f);
241 float squared_error = error * error;
242 float squared_rel_error = rel_error * rel_error;
243 scoreError->numScores++;
244 scoreError->sumError += error;
245 scoreError->sumSquaredError += squared_error;
246 if (error > scoreError->maxError) {
247 scoreError->maxError = error;
249 scoreError->sumRelError += rel_error;
250 scoreError->sumSquaredRelError += squared_rel_error;
251 if (rel_error > scoreError->maxRelError) {
252 scoreError->maxRelError = rel_error;
// Count elements outside the acceptable tolerance.
254 if (error > scoreError->threshold) {
// RMS over the whole matrix; sumRmsError accumulates across utterances.
259 scoreError->rmsError = sqrt(scoreError->sumSquaredError / (numRows * numColumns));
260 scoreError->sumRmsError += scoreError->rmsError;
261 scoreError->numErrors = numErrors;
266 float StdDevError(score_error_t error) {
267 return (sqrt(error.sumSquaredError / error.numScores
268 - (error.sumError / error.numScores) * (error.sumError / error.numScores)));
271 float StdDevRelError(score_error_t error) {
272 return (sqrt(error.sumSquaredRelError / error.numScores
273 - (error.sumRelError / error.numScores) * (error.sumRelError / error.numScores)));
// x86-only CPUID support (used to pick the GNA clock frequency); compiled out
// on ARM targets.
// NOTE(review): the platform-specific #include lines (<intrin.h> vs
// <cpuid.h>), the definition of `level` (presumably taken from *eax before
// the call) and the closing #endif directives are elided from this view.
276 #if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
277 #if defined(_WIN32) || defined(WIN32)
// Execute CPUID portably: MSVC intrinsic __cpuid vs GCC/Clang __get_cpuid.
// Register values are passed in and out through the four pointers.
286 inline void native_cpuid(unsigned int *eax, unsigned int *ebx,
287 unsigned int *ecx, unsigned int *edx) {
289 #if defined(_WIN32) || defined(WIN32)
290 int regs[4] = {static_cast<int>(*eax), static_cast<int>(*ebx), static_cast<int>(*ecx), static_cast<int>(*edx)};
291 __cpuid(regs, level);
292 *eax = static_cast<uint32_t>(regs[0]);
293 *ebx = static_cast<uint32_t>(regs[1]);
294 *ecx = static_cast<uint32_t>(regs[2]);
295 *edx = static_cast<uint32_t>(regs[3]);
297 __get_cpuid(level, eax, ebx, ecx, edx);
301 // return GNA module frequency in MHz
// Chooses the frequency from the CPU family/model reported by CPUID:
// Cannon Lake and Gemini Lake have known GNA clocks; other models fall back
// to a default. NOTE(review): the eax/ebx/ecx/edx/family/model declarations,
// the CPUID leaf setup, and the actual returned constants are elided from
// this view.
302 float getGnaFrequencyMHz() {
309 const uint8_t sixth_family = 6;
310 const uint8_t cannon_lake_model = 102;
311 const uint8_t gemini_lake_model = 122;
313 native_cpuid(&eax, &ebx, &ecx, &edx);
314 family = (eax >> 8) & 0xF;
316 // model is the concatenation of two fields
317 // | extended model | model |
318 // copy extended model data
319 model = (eax >> 16) & 0xF;
// Merge in the base model bits (the extended bits are shifted on an elided line).
323 model += (eax >> 4) & 0xF;
325 if (family == sixth_family && model == cannon_lake_model) {
327 } else if (family == sixth_family &&
328 model == gemini_lake_model) {
331 // counters not supported and we return just the default value
// Print the accumulated score-vs-reference error summary for one utterance.
// NOTE(review): the framesNum parameter line is elided from this signature
// view; it divides sumRmsError to yield the average RMS per frame.
338 void printReferenceCompareResults(score_error_t const &totalError,
340 std::ostream &stream) {
341 stream << "         max error: " <<
342 totalError.maxError << std::endl;
343 stream << "         avg error: " <<
344 totalError.sumError / totalError.numScores << std::endl;
345 stream << "     avg rms error: " <<
346 totalError.sumRmsError / framesNum << std::endl;
347 stream << "       stdev error: " <<
348 StdDevError(totalError) << std::endl << std::endl;
// Pretty-print per-layer performance counters for one utterance. For GNA the
// raw counters are scaled by the module frequency from getGnaFrequencyMHz()
// to obtain milliseconds. Compiled out on ARM builds (no cpuid support).
// NOTE(review): the callsNum parameter line, the non-GNA branch and the
// counter-filtering logic are elided from this view.
352 void printPerformanceCounters(std::map<std::string,
353 InferenceEngine::InferenceEngineProfileInfo> const &utterancePerfMap,
355 std::ostream &stream, std::string fullDeviceName) {
356 #if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
// Preserve the caller's stream formatting; restored before returning.
357 std::ios_base::fmtflags fmt_flags(stream.flags() );
358 stream << std::endl << "Performance counts:" << std::endl;
359 stream << std::setw(10) << std::right << "" << "Counter descriptions";
360 stream << std::setw(22) << "Utt scoring time";
361 stream << std::setw(18) << "Avg infer time";
364 stream << std::setw(46) << "(ms)";
365 stream << std::setw(24) << "(us per call)";
368 for (const auto &it : utterancePerfMap) {
369 std::string const &counter_name = it.first;
370 float current_units = static_cast<float>(it.second.realTime_uSec);
371 float call_units = current_units / callsNum;
372 // if GNA HW counters
373 // get frequency of GNA module
374 float freq = getGnaFrequencyMHz();
// Convert GNA counter units to milliseconds using the module frequency.
375 current_units /= freq * 1000;
// Strip the counter-name prefix before printing.
377 stream << std::setw(30) << std::left << counter_name.substr(4, counter_name.size() - 1);
378 stream << std::setw(16) << std::right << current_units;
379 stream << std::setw(21) << std::right << call_units;
383 std::cout << std::endl;
384 std::cout << "Full device name: " << fullDeviceName << std::endl;
385 std::cout << std::endl;
386 stream.flags(fmt_flags);
390 void getPerformanceCounters(InferenceEngine::InferRequest &request,
391 std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &perfCounters) {
392 auto retPerfCounters = request.GetPerformanceCounts();
394 for (const auto &pair : retPerfCounters) {
395 perfCounters[pair.first] = pair.second;
399 void sumPerformanceCounters(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> const &perfCounters,
400 std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &totalPerfCounters) {
401 for (const auto &pair : perfCounters) {
402 totalPerfCounters[pair.first].realTime_uSec += pair.second.realTime_uSec;
// Parse command-line flags via gflags and validate flag combinations; throws
// std::logic_error on invalid usage.
// NOTE(review): the help-flag branch, the full supported-device list, several
// closing braces and the final `return true;` are elided from this view.
406 bool ParseAndCheckCommandLine(int argc, char *argv[]) {
407 // ---------------------------Parsing and validation of input args--------------------------------------
408 slog::info << "Parsing input parameters" << slog::endl;
410 gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
413 showAvailableDevices();
// Dump mode = writing a GNA model image (-wg) or embedded image (-we).
416 bool isDumpMode = !FLAGS_wg.empty() || !FLAGS_we.empty();
418 // input not required only in dump mode and if external scale factor provided
419 if (FLAGS_i.empty() && (!isDumpMode || FLAGS_q.compare("user") != 0)) {
421 throw std::logic_error("In model dump mode either static quantization is used (-i) or user scale"
422 " factor need to be provided. See -q user option");
424 throw std::logic_error("Input file not set. Please use -i.");
// Exactly one of the IR model (-m) and imported GNA model (-rg) must be given.
427 if (FLAGS_m.empty() && FLAGS_rg.empty()) {
428 throw std::logic_error("Either IR file (-m) or GNAModel file (-rg) need to be set.");
431 if ((!FLAGS_m.empty() && !FLAGS_rg.empty())) {
432 throw std::logic_error("Only one of -m and -rg is allowed.");
435 std::vector<std::string> supportedDevices = {
445 "HETERO:GNA_SW_EXACT,CPU",
447 "HETERO:GNA_SW_FP32,CPU",
451 if (std::find(supportedDevices.begin(), supportedDevices.end(), FLAGS_d) == supportedDevices.end()) {
452 throw std::logic_error("Specified device is not supported.");
// NOTE(review): the message says "non-negative" but the check rejects zero —
// the message should read "positive".
455 float scaleFactorInput = static_cast<float>(FLAGS_sf);
456 if (scaleFactorInput <= 0.0f) {
457 throw std::logic_error("Scale factor out of range (must be non-negative).");
460 uint32_t batchSize = (uint32_t) FLAGS_bs;
461 if ((batchSize < 1) || (batchSize > 8)) {
462 throw std::logic_error("Batch size out of range (1..8).");
465 /** default is a static quantisation **/
466 if ((FLAGS_q.compare("static") != 0) && (FLAGS_q.compare("dynamic") != 0) && (FLAGS_q.compare("user") != 0)) {
467 throw std::logic_error("Quantization mode not supported (static, dynamic, user).");
470 if (FLAGS_q.compare("dynamic") == 0) {
471 throw std::logic_error("Dynamic quantization not yet supported.");
474 if (FLAGS_qb != 16 && FLAGS_qb != 8) {
475 throw std::logic_error("Only 8 or 16 bits supported.");
478 if (FLAGS_nthreads <= 0) {
479 throw std::logic_error("Invalid value for 'nthreads' argument. It must be greater that or equal to 0");
// Context-window flags must be non-negative frame counts.
482 if (FLAGS_cw_r < 0) {
483 throw std::logic_error("Invalid value for 'cw_r' argument. It must be greater than or equal to 0");
486 if (FLAGS_cw_l < 0) {
487 throw std::logic_error("Invalid value for 'cw_l' argument. It must be greater than or equal to 0");
490 // RH decoder parameters
491 if (FLAGS_hmm.empty()) {
492 throw std::logic_error("RH HMM model file not set. Please use -hmm.");
494 if (FLAGS_labels.empty()) {
495 throw std::logic_error("RH labels file not set. Please use -labels.");
497 if (FLAGS_g.empty()) {
498 throw std::logic_error("RH LM: G.fst model file not set. Please use -g.");
500 if (FLAGS_cl.empty()) {
501 throw std::logic_error("RH LM: CL.fst model file not set. Please use -cl.");
/**
 * @brief Read a whole binary file into a newly allocated buffer.
 *
 * The caller owns the returned buffer and must release it with delete[].
 *
 * @param filename path of the file to read
 * @param size     out-parameter receiving the number of bytes read; must not be null
 * @return pointer to a new uint8_t[*size] buffer holding the file contents
 * @throws std::logic_error   when @p size is null
 * @throws std::runtime_error on open/seek/tell/read failure or allocation failure
 */
uint8_t* ReadBinaryFile(const char* filename, uint32_t* size) {
    if (nullptr == size) {
        throw std::logic_error("Size parameter is null");
    }

    // RAII wrapper so the FILE* is closed on every path, including throws
    // (the original leaked the handle on the error paths).
    std::unique_ptr<FILE, int (*)(FILE*)> f(fopen(filename, "rb"), &fclose);
    if (nullptr == f) {
        throw std::runtime_error("Failed to open binary file " + std::string(filename));
    }

    // Determine the file size by seeking to the end.
    if (fseek(f.get(), 0, SEEK_END) != 0) {
        throw std::runtime_error("Error occured while loading (fseek) file " + std::string(filename));
    }
    const long fileSize = ftell(f.get());
    if (fileSize < 0) {
        throw std::runtime_error("Error occured while loading (ftell) file " + std::string(filename));
    }
    if (fseek(f.get(), 0, SEEK_SET) != 0) {
        throw std::runtime_error("Error occured while loading (fseek) file " + std::string(filename));
    }

    // Buffer ownership is transferred to the caller only on success.
    std::unique_ptr<uint8_t[]> data(new (std::nothrow) uint8_t[fileSize]);
    if (nullptr == data) {
        throw std::runtime_error("Not enough memory to load file " + std::string(filename));
    }

    *size = static_cast<uint32_t>(fread(data.get(), 1, static_cast<size_t>(fileSize), f.get()));
    if (*size != static_cast<uint32_t>(fileSize)) {
        throw std::runtime_error("Could not read all the data from file " + std::string(filename));
    }
    return data.release();
}
548 void InitializeRhDecoder(RhDecoderInstanceParams& instanceParams, int32_t scoreVectorSize) {
549 uint32_t hmm_size = 0;
550 uint32_t cl_size = 0;
552 uint32_t label_size = 0;
554 instanceParams.hmm_data = ReadBinaryFile(FLAGS_hmm.c_str(), &hmm_size);
555 instanceParams.cl_data = ReadBinaryFile(FLAGS_cl.c_str(), &cl_size);
556 instanceParams.g_data = ReadBinaryFile(FLAGS_g.c_str(), &g_size);
557 instanceParams.label_data = ReadBinaryFile(FLAGS_labels.c_str(), &label_size);
559 if (instanceParams.hmm_data && instanceParams.cl_data &&
560 instanceParams.g_data && instanceParams.label_data) {
561 RhDecoderStatus status = RhDecoderCreateInstance(&instanceParams.handle);
564 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
565 throw std::logic_error("Failed to create decoder");
568 status = RhDecoderSetDefaultParameterValues(instanceParams.handle,
569 RhAcousticModelType::RH_ACOUSTIC_MODEL_TYPE_GENERIC_CHAIN);
570 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
571 throw std::logic_error("Failed to set default decoder values");
574 // now overwrite some of the parameters
575 float acoustic_scale_factor = static_cast<float>(FLAGS_amsf);
576 status = RhDecoderSetParameterValue(instanceParams.handle, RH_DECODER_ACOUSTIC_SCALE_FACTOR,
577 &acoustic_scale_factor, sizeof(float));
578 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
579 throw std::logic_error("Failed to set parameter acoustic_scale_factor value");
582 status = RhDecoderSetParameterValue(instanceParams.handle, RH_DECODER_ACOUSTIC_SCORE_VECTOR_SIZE,
583 &scoreVectorSize, sizeof(int));
584 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
585 throw std::logic_error("Failed to set parameter score_vector_size value");
588 float beam_width = static_cast<float>(FLAGS_beam_width);
589 status = RhDecoderSetParameterValue(instanceParams.handle, RH_DECODER_BEAM_WIDTH,
590 &beam_width, sizeof(float));
591 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
592 throw std::logic_error("Failed to set parameter beam_width value");
595 status = RhDecoderSetParameterValue(instanceParams.handle, RH_DECODER_NBEST,
596 &FLAGS_nbest, sizeof(int));
597 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
598 throw std::logic_error("Failed to set parameter nbest value");
601 status = RhDecoderSetParameterValue(instanceParams.handle, RH_DECODER_G_CACHE_LOG_SIZE,
602 &FLAGS_gcls, sizeof(int));
603 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
604 throw std::logic_error("Failed to set parameter g_cache_log_size value");
607 status = RhDecoderSetParameterValue(instanceParams.handle, RH_DECODER_TRACE_BACK_LOG_SIZE,
608 &FLAGS_tbls, sizeof(int));
609 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
610 throw std::logic_error("Failed to set parameter trace_back_log_size value");
613 status = RhDecoderSetParameterValue(instanceParams.handle, RH_DECODER_MIN_STABLE_FRAMES,
614 &FLAGS_msf, sizeof(int));
615 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
616 throw std::logic_error("Failed to set parameter min_stable_frames value");
619 status = RhDecoderSetParameterValue(instanceParams.handle, RH_DECODER_TOKEN_BUFFER_SIZE,
620 &FLAGS_tbs, sizeof(int));
621 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
622 throw std::logic_error("Failed to set parameter token_buffer_size value");
625 status = RhDecoderSetupResource(instanceParams.handle,
626 RhResourceType::HMM, instanceParams.hmm_data, hmm_size);
627 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
628 throw std::logic_error("Failed to load HMM model");
631 status = RhDecoderSetupResource(instanceParams.handle,
632 RhResourceType::PRONUNCIATION_MODEL, instanceParams.cl_data, cl_size);
633 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
634 throw std::logic_error("Failed to load pronunciation model");
637 status = RhDecoderSetupResource(instanceParams.handle,
638 RhResourceType::LANGUAGE_MODEL, instanceParams.g_data, g_size);
639 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
640 throw std::logic_error("Failed to load language model");
643 status = RhDecoderSetupResource(instanceParams.handle,
644 RhResourceType::LABELS, instanceParams.label_data, label_size);
645 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
646 throw std::logic_error("Failed to load labels");
649 status = RhDecoderInitInstance(instanceParams.handle);
650 if (RhDecoderStatus::RH_DECODER_SUCCESS != status) {
651 throw std::logic_error("Failed to initialize decoder");
655 throw std::logic_error("Failed to read one of the resources");
659 void FreeRhDecoder(RhDecoderInstanceParams& instanceParams) {
660 if (instanceParams.handle) {
661 RhDecoderStatus status = RhDecoderFreeInstance(instanceParams.handle);
662 if (status != RH_DECODER_SUCCESS) {
663 slog::err << "Failed to free decoder. Status: " << status << slog::endl;
664 throw std::logic_error("Failed to free decoder. Status: " + std::to_string(status));
668 if (instanceParams.hmm_data) {
669 delete[] instanceParams.hmm_data;
670 instanceParams.hmm_data = nullptr;
673 if (instanceParams.cl_data) {
674 delete[] instanceParams.cl_data;
675 instanceParams.cl_data = nullptr;
678 if (instanceParams.g_data) {
679 delete[] instanceParams.g_data;
680 instanceParams.g_data = nullptr;
683 if (instanceParams.label_data) {
684 delete[] instanceParams.label_data;
685 instanceParams.label_data = nullptr;
690 * @brief The entry point for inference engine automatic speech recognition sample
691 * @file speech_sample/main.cpp
692 * @example speech_sample/main.cpp
694 int main(int argc, char *argv[]) {
696 slog::info << "InferenceEngine: " << GetInferenceEngineVersion() << slog::endl;
698 // ------------------------------ Parsing and validation of input args ---------------------------------
699 if (!ParseAndCheckCommandLine(argc, argv)) {
703 if (FLAGS_l.empty()) {
704 slog::info << "No extensions provided" << slog::endl;
707 auto isFeature = [&](const std::string xFeature) { return FLAGS_d.find(xFeature) != std::string::npos; };
709 bool useGna = isFeature("GNA");
710 bool useHetero = isFeature("HETERO");
711 std::string deviceStr =
712 useHetero && useGna ? "HETERO:GNA,CPU" : FLAGS_d.substr(0, (FLAGS_d.find("_")));
713 float scaleFactorInput = static_cast<float>(FLAGS_sf);
714 uint32_t batchSize = (FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : (uint32_t) FLAGS_bs;
716 std::vector<std::string> inputArkFiles;
717 std::vector<uint32_t> numBytesThisUtterance;
718 uint32_t numUtterances(0);
719 if (!FLAGS_i.empty()) {
721 std::istringstream stream(FLAGS_i);
723 uint32_t currentNumUtterances(0), currentNumBytesThisUtterance(0);
724 while (getline(stream, outStr, ',')) {
725 std::string filename(fileNameNoExt(outStr) + ".ark");
726 inputArkFiles.push_back(filename);
728 GetKaldiArkInfo(filename.c_str(), 0, ¤tNumUtterances, ¤tNumBytesThisUtterance);
729 if (numUtterances == 0) {
730 numUtterances = currentNumUtterances;
731 } else if (currentNumUtterances != numUtterances) {
732 throw std::logic_error("Incorrect input files. Number of utterance must be the same for all ark files");
734 numBytesThisUtterance.push_back(currentNumBytesThisUtterance);
737 size_t numInputArkFiles(inputArkFiles.size());
738 // -----------------------------------------------------------------------------------------------------
740 // --------------------------- 1. Load inference engine -------------------------------------
741 slog::info << "Loading Inference Engine" << slog::endl;
744 /** Printing device version **/
745 slog::info << "Device info: " << slog::endl;
746 std::cout << ie.GetVersions(deviceStr) << std::endl;
747 // -----------------------------------------------------------------------------------------------------
749 // --------------------------- 2. Read IR Generated by ModelOptimizer (.xml and .bin files) ------------
750 slog::info << "Loading network files" << slog::endl;
752 CNNNetReader netBuilder;
753 if (!FLAGS_m.empty()) {
754 /** Read network model **/
755 netBuilder.ReadNetwork(FLAGS_m);
757 /** Extract model name and load weights **/
758 std::string binFileName = fileNameNoExt(FLAGS_m) + ".bin";
759 netBuilder.ReadWeights(binFileName);
761 // -------------------------------------------------------------------------------------------------
763 // --------------------------- 3. Set batch size ---------------------------------------------------
764 /** Set batch size. Unlike in imaging, batching in time (rather than space) is done for speech recognition. **/
765 netBuilder.getNetwork().setBatchSize(batchSize);
766 slog::info << "Batch size is " << std::to_string(netBuilder.getNetwork().getBatchSize())
770 /** Setting parameter for per layer metrics **/
771 std::map<std::string, std::string> gnaPluginConfig;
772 std::map<std::string, std::string> genericPluginConfig;
774 std::string gnaDevice =
775 useHetero ? FLAGS_d.substr(FLAGS_d.find("GNA"), FLAGS_d.find(",") - FLAGS_d.find("GNA")) : FLAGS_d;
776 gnaPluginConfig[GNAConfigParams::KEY_GNA_DEVICE_MODE] =
777 gnaDevice.find("_") == std::string::npos ? "GNA_AUTO" : gnaDevice;
781 genericPluginConfig[PluginConfigParams::KEY_PERF_COUNT] = PluginConfigParams::YES;
784 if (FLAGS_q.compare("user") == 0) {
785 if (numInputArkFiles > 1) {
786 std::string errMessage("Incorrect use case for multiple input ark files. Please don't use -q 'user' for this case.");
787 throw std::logic_error(errMessage);
789 slog::info << "Using scale factor of " << FLAGS_sf << slog::endl;
790 gnaPluginConfig[GNA_CONFIG_KEY(SCALE_FACTOR)] = std::to_string(FLAGS_sf);
792 // "static" quantization with calculated scale factor
793 for (size_t i = 0; i < numInputArkFiles; i++) {
794 auto inputArkName = inputArkFiles[i].c_str();
796 std::vector<uint8_t> ptrFeatures;
797 uint32_t numArrays(0), numBytes(0), numFrames(0), numFrameElements(0), numBytesPerElement(0);
798 GetKaldiArkInfo(inputArkName, 0, &numArrays, &numBytes);
799 ptrFeatures.resize(numBytes);
800 LoadKaldiArkArray(inputArkName,
806 &numBytesPerElement);
808 ScaleFactorForQuantization(ptrFeatures.data(), MAX_VAL_2B_FEAT, numFrames * numFrameElements);
809 slog::info << "Using scale factor of " << scaleFactorInput << " calculated from first utterance."
811 std::string scaleFactorConfigKey = GNA_CONFIG_KEY(SCALE_FACTOR) + std::string("_") + std::to_string(i);
812 gnaPluginConfig[scaleFactorConfigKey] = std::to_string(scaleFactorInput);
817 gnaPluginConfig[GNAConfigParams::KEY_GNA_PRECISION] = "I8";
819 gnaPluginConfig[GNAConfigParams::KEY_GNA_PRECISION] = "I16";
822 gnaPluginConfig[GNAConfigParams::KEY_GNA_LIB_N_THREADS] = std::to_string((FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : FLAGS_nthreads);
823 gnaPluginConfig[GNA_CONFIG_KEY(COMPACT_MODE)] = CONFIG_VALUE(NO);
824 // -----------------------------------------------------------------------------------------------------
826 // --------------------------- 4. Write model to file --------------------------------------------------
827 // Embedded GNA model dumping (for Intel(R) Speech Enabling Developer Kit)
828 if (!FLAGS_we.empty()) {
829 gnaPluginConfig[GNAConfigParams::KEY_GNA_FIRMWARE_MODEL_IMAGE] = FLAGS_we;
831 // -----------------------------------------------------------------------------------------------------
833 // --------------------------- 5. Loading model to the device ------------------------------------------
836 genericPluginConfig.insert(std::begin(gnaPluginConfig), std::end(gnaPluginConfig));
838 auto t0 = Time::now();
839 ExecutableNetwork executableNet;
841 if (!FLAGS_m.empty()) {
842 slog::info << "Loading model to the device" << slog::endl;
843 executableNet = ie.LoadNetwork(netBuilder.getNetwork(), deviceStr, genericPluginConfig);
845 slog::info << "Importing model to the device" << slog::endl;
846 executableNet = ie.ImportNetwork(FLAGS_rg.c_str(), deviceStr, genericPluginConfig);
849 ms loadTime = std::chrono::duration_cast<ms>(Time::now() - t0);
850 slog::info << "Model loading time " << loadTime.count() << " ms" << slog::endl;
852 // --------------------------- 6. Exporting gna model using InferenceEngine AOT API---------------------
853 if (!FLAGS_wg.empty()) {
854 slog::info << "Writing GNA Model to file " << FLAGS_wg << slog::endl;
856 executableNet.Export(FLAGS_wg);
857 ms exportTime = std::chrono::duration_cast<ms>(Time::now() - t0);
858 slog::info << "Exporting time " << exportTime.count() << " ms" << slog::endl;
862 if (!FLAGS_we.empty()) {
863 slog::info << "Exported GNA embedded model to file " << FLAGS_we << slog::endl;
867 std::vector<InferRequestStruct> inferRequests((FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : FLAGS_nthreads);
868 for (auto& inferRequest : inferRequests) {
869 inferRequest = {executableNet.CreateInferRequest(), -1, batchSize};
871 // -----------------------------------------------------------------------------------------------------
873 // --------------------------- 7. Prepare input blobs --------------------------------------------------
874 /** Taking information about all topology inputs **/
875 ConstInputsDataMap cInputInfo = executableNet.GetInputsInfo();
876 /** Stores all input blobs data **/
877 if (cInputInfo.size() != numInputArkFiles) {
878 throw std::logic_error("Number of network inputs("
879 + std::to_string(cInputInfo.size()) + ") is not equal to number of ark files("
880 + std::to_string(numInputArkFiles) + ")");
883 std::vector<Blob::Ptr> ptrInputBlobs;
884 for (auto& input : cInputInfo) {
885 ptrInputBlobs.push_back(inferRequests.begin()->inferRequest.GetBlob(input.first));
888 InputsDataMap inputInfo;
889 if (!FLAGS_m.empty()) {
890 inputInfo = netBuilder.getNetwork().getInputsInfo();
892 /** configure input precision if model loaded from IR **/
893 for (auto &item : inputInfo) {
894 Precision inputPrecision = Precision::FP32; // specify Precision::I16 to provide quantized inputs
895 item.second->setPrecision(inputPrecision);
896 item.second->getInputData()->setLayout(Layout::NC); // row major layout
899 // -----------------------------------------------------------------------------------------------------
901 // --------------------------- 8. Prepare output blobs -------------------------------------------------
902 ConstOutputsDataMap cOutputInfo(executableNet.GetOutputsInfo());
903 OutputsDataMap outputInfo;
904 if (!FLAGS_m.empty()) {
905 outputInfo = netBuilder.getNetwork().getOutputsInfo();
908 Blob::Ptr ptrOutputBlob = inferRequests.begin()->inferRequest.GetBlob(cOutputInfo.rbegin()->first);
910 for (auto &item : outputInfo) {
911 DataPtr outData = item.second;
913 throw std::logic_error("output data pointer is not valid");
916 Precision outputPrecision = Precision::FP32; // specify Precision::I32 to retrieve quantized outputs
917 outData->setPrecision(outputPrecision);
918 outData->setLayout(Layout::NC); // row major layout
920 // -----------------------------------------------------------------------------------------------------
922 // --------------------------- 9. Initialize RH decoder ------------------------------------------------
924 RhDecoderInstanceParams rhDecoderInstanceParams{ nullptr };
925 auto lastLayerOutputCount = outputInfo.begin()->second->getDims()[1];
926 InitializeRhDecoder(rhDecoderInstanceParams, lastLayerOutputCount);
928 // allocate 1MB for result
929 std::vector<char> rh_utterance_transcription(1024 * 1024);
931 // -----------------------------------------------------------------------------------------------------
933 // --------------------------- 10. Do inference --------------------------------------------------------
935 std::vector<std::vector<uint8_t>> ptrUtterances;
936 std::vector<uint8_t> ptrScores;
937 std::vector<uint8_t> ptrReferenceScores;
938 score_error_t frameError, totalError;
940 ptrUtterances.resize(inputArkFiles.size());
941 for (uint32_t utteranceIndex = 0; utteranceIndex < numUtterances; ++utteranceIndex) {
942 std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> utterancePerfMap;
944 uint32_t numFrames(0), n(0);
945 std::vector<uint32_t> numFrameElementsInput;
947 uint32_t numFramesReference(0), numFrameElementsReference(0), numBytesPerElementReference(0),
948 numBytesReferenceScoreThisUtterance(0);
949 const uint32_t numScoresPerFrame = ptrOutputBlob->size() / batchSize;
951 numFrameElementsInput.resize(numInputArkFiles);
952 for (size_t i = 0; i < inputArkFiles.size(); i++) {
953 std::vector<uint8_t> ptrUtterance;
954 auto inputArkFilename = inputArkFiles[i].c_str();
955 uint32_t currentNumFrames(0), currentNumFrameElementsInput(0), currentNumBytesPerElementInput(0);
956 GetKaldiArkInfo(inputArkFilename, utteranceIndex, &n, &numBytesThisUtterance[i]);
957 ptrUtterance.resize(numBytesThisUtterance[i]);
958 LoadKaldiArkArray(inputArkFilename,
963 ¤tNumFrameElementsInput,
964 ¤tNumBytesPerElementInput);
965 if (numFrames == 0) {
966 numFrames = currentNumFrames;
967 } else if (numFrames != currentNumFrames) {
968 std::string errMessage("Number of frames in ark files is different: " + std::to_string(numFrames) +
969 " and " + std::to_string(currentNumFrames));
970 throw std::logic_error(errMessage);
973 ptrUtterances[i] = ptrUtterance;
974 numFrameElementsInput[i] = currentNumFrameElementsInput;
978 for (auto& ptrInputBlob : ptrInputBlobs) {
979 if (ptrInputBlob->size() != numFrameElementsInput[i++] * batchSize) {
980 throw std::logic_error("network input size(" + std::to_string(ptrInputBlob->size()) +
981 ") mismatch to ark file size (" +
982 std::to_string(numFrameElementsInput[i-1] * batchSize) + ")");
986 ptrScores.resize(numFrames * numScoresPerFrame * sizeof(float));
987 if (!FLAGS_r.empty()) {
988 std::string refUtteranceName;
989 GetKaldiArkInfo(FLAGS_r.c_str(), utteranceIndex, &n, &numBytesReferenceScoreThisUtterance);
990 ptrReferenceScores.resize(numBytesReferenceScoreThisUtterance);
991 LoadKaldiArkArray(FLAGS_r.c_str(),
996 &numFrameElementsReference,
997 &numBytesPerElementReference);
1000 double totalTime = 0.0;
1002 std::cout << "Utterance " << utteranceIndex << ": " << std::endl;
1004 ClearScoreError(&totalError);
1005 totalError.threshold = frameError.threshold = MAX_SCORE_DIFFERENCE;
1006 auto outputFrame = &ptrScores.front();
1007 std::vector<uint8_t*> inputFrame;
1008 for (auto& ut : ptrUtterances) {
1009 inputFrame.push_back(&ut.front());
1012 std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> callPerfMap;
1014 size_t frameIndex = 0;
1015 uint32_t numFramesArkFile = numFrames;
1016 numFrames += FLAGS_cw_l + FLAGS_cw_r;
1017 uint32_t numFramesThisBatch{batchSize};
1019 auto t0 = Time::now();
1022 while (frameIndex <= numFrames) {
1023 if (frameIndex == numFrames) {
1024 if (std::find_if(inferRequests.begin(),
1025 inferRequests.end(),
1026 [&](InferRequestStruct x) { return (x.frameIndex != -1); } ) == inferRequests.end()) {
1031 bool inferRequestFetched = false;
1032 for (auto &inferRequest : inferRequests) {
1033 if (frameIndex == numFrames) {
1034 numFramesThisBatch = 1;
1036 numFramesThisBatch = (numFrames - frameIndex < batchSize) ? (numFrames - frameIndex)
1040 if (inferRequest.frameIndex != -1) {
1041 StatusCode code = inferRequest.inferRequest.Wait(
1042 InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
1044 if (code != StatusCode::OK) {
1045 if (!useHetero) continue;
1046 if (code != StatusCode::INFER_NOT_STARTED) continue;
1049 if (inferRequest.frameIndex >= 0) {
1050 Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(cOutputInfo.rbegin()->first);
1052 const float* acoustic_score_vector_index = outputBlob->buffer();
1054 for (uint32_t f = 0; f < inferRequest.numFramesThisBatch; ++f) {
1055 RhDecoderStatus rh_status = RhDecoderProcessFrame(rhDecoderInstanceParams.handle,
1056 acoustic_score_vector_index, numScoresPerFrame, &info);
1057 if (RhDecoderStatus::RH_DECODER_SUCCESS != rh_status) {
1058 throw std::logic_error(
1059 "Decoder failed to process frame: " + std::to_string(inferRequest.frameIndex));
1061 if (info.is_result_stable || inferRequest.frameIndex + f == numFrames - 1) {
1062 RhDecoderGetResult(rhDecoderInstanceParams.handle,
1063 RhDecoderResultType::RH_DECODER_FINAL_RESULT,
1064 rh_utterance_transcription.data(),
1065 rh_utterance_transcription.size());
1066 if (RhDecoderStatus::RH_DECODER_SUCCESS != rh_status) {
1067 throw std::logic_error("Failed to retrieve speech recognition result");
1070 std::cout << uttName << "\t" << rh_utterance_transcription.data() << std::endl;
1073 acoustic_score_vector_index += lastLayerOutputCount;
1076 if (!FLAGS_o.empty()) {
1078 &ptrScores.front() + numScoresPerFrame * sizeof(float) * (inferRequest.frameIndex);
1079 Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(cOutputInfo.rbegin()->first);
1080 auto byteSize = inferRequest.numFramesThisBatch * numScoresPerFrame * sizeof(float);
1081 std::memcpy(outputFrame,
1082 outputBlob->buffer(),
1086 if (!FLAGS_r.empty()) {
1087 Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(cOutputInfo.begin()->first);
1088 CompareScores(outputBlob->buffer().as<float*>(),
1089 &ptrReferenceScores[inferRequest.frameIndex *
1090 numFrameElementsReference *
1091 numBytesPerElementReference],
1093 inferRequest.numFramesThisBatch,
1094 numFrameElementsReference);
1095 UpdateScoreError(&frameError, &totalError);
1098 // retrieve new counters
1099 getPerformanceCounters(inferRequest.inferRequest, callPerfMap);
1100 // summarize retrieved counters with all previous
1101 sumPerformanceCounters(callPerfMap, utterancePerfMap);
1106 if (frameIndex == numFrames) {
1107 inferRequest.frameIndex = -1;
1111 ptrInputBlobs.clear();
1112 for (auto& input : cInputInfo) {
1113 ptrInputBlobs.push_back(inferRequest.inferRequest.GetBlob(input.first));
1116 for (size_t i = 0; i < numInputArkFiles; ++i) {
1117 std::memcpy(ptrInputBlobs[i]->buffer(),
1119 ptrInputBlobs[i]->byteSize());
1122 int index = static_cast<int>(frameIndex) - (FLAGS_cw_l + FLAGS_cw_r);
1123 inferRequest.inferRequest.StartAsync();
1124 inferRequest.frameIndex = index < 0 ? -2 : index;
1125 inferRequest.numFramesThisBatch = numFramesThisBatch;
1127 frameIndex += numFramesThisBatch;
1128 for (size_t j = 0; j < inputArkFiles.size(); j++) {
1129 if (FLAGS_cw_l > 0 || FLAGS_cw_r > 0) {
1130 int i = frameIndex - FLAGS_cw_l;
1131 if (i > 0 && i < static_cast<int>(numFramesArkFile)) {
1132 inputFrame[j] += sizeof(float) * numFrameElementsInput[j] * numFramesThisBatch;
1133 } else if (i >= static_cast<int>(numFramesArkFile)) {
1134 inputFrame[j] = &ptrUtterances[0].front() +
1135 (numFramesArkFile - 1) * sizeof(float) * numFrameElementsInput[j] * numFramesThisBatch;
1137 inputFrame[j] = &ptrUtterances[0].front();
1140 inputFrame[j] += sizeof(float) * numFrameElementsInput[j] * numFramesThisBatch;
1143 inferRequestFetched |= true;
1146 if (!inferRequestFetched) {
1147 std::this_thread::sleep_for(std::chrono::milliseconds(1));
1154 ms d = std::chrono::duration_cast<ms>(fs);
1155 totalTime += d.count();
1157 // resetting state between utterances
1158 for (auto &&state : executableNet.QueryState()) {
1162 if (!FLAGS_o.empty()) {
1163 bool shouldAppend = (utteranceIndex == 0) ? false : true;
1164 SaveKaldiArkArray(FLAGS_o.c_str(), shouldAppend, uttName, &ptrScores.front(),
1165 numFrames, numScoresPerFrame);
1168 /** Show performance results **/
1169 std::cout << "Total time in Infer (HW and SW):\t" << totalTime << " ms"
1171 std::cout << "Frames in utterance:\t\t\t" << numFrames << " frames"
1173 std::cout << "Average Infer time per frame:\t\t" << totalTime / static_cast<double>(numFrames) << " ms"
1177 printPerformanceCounters(utterancePerfMap, frameIndex, std::cout, getFullDeviceName(ie, FLAGS_d));
1179 if (!FLAGS_r.empty()) {
1180 printReferenceCompareResults(totalError, numFrames, std::cout);
1182 std::cout << "End of Utterance " << utteranceIndex << std::endl << std::endl;
1185 FreeRhDecoder(rhDecoderInstanceParams);
1186 // -----------------------------------------------------------------------------------------------------
1188 catch (const std::exception &error) {
1189 slog::err << error.what() << slog::endl;
1193 slog::err << "Unknown/internal exception happened" << slog::endl;
1197 slog::info << "Execution successful" << slog::endl;