[GNA][Speech sample] Add option to specify blob names (#1529)
author    Andrey Dmitriev <andrey.dmitriev@intel.com>
Fri, 16 Oct 2020 12:34:22 +0000 (15:34 +0300)
committer GitHub <noreply@github.com>
Fri, 16 Oct 2020 12:34:22 +0000 (15:34 +0300)
* Add output names

* Add input, output, and reference names

* Add zero scale factor

* Add support for multiple reference files
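
For illustration, the new flags combine with the existing sample options roughly as follows (the model, ark file, and layer names are placeholders, not taken from this change):

    speech_sample -d GNA_AUTO -m model.xml -i input1.ark,input2.ark \
                  -iname Input1,Input2 -oname affine1:0,affine2:0 \
                  -o scores1.ark,scores2.ark -r ref1.ark,ref2.ark

When -oname lists several outputs, -o and -r take a matching comma-separated list of files, one per output.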

inference-engine/samples/speech_sample/main.cpp
inference-engine/samples/speech_sample/speech_sample.hpp

index ef36d91..5bae50b 100644 (file)
@@ -426,6 +426,20 @@ std::vector<std::string> ParseScaleFactors(const std::string& str) {
     return scaleFactorInput;
 }
 
+/// @brief Splits a comma-separated list of blob/file names (as passed via -iname, -oname, -o or -r) into a vector
+std::vector<std::string> ParseBlobName(std::string str) {
+    std::vector<std::string> blobName;
+    if (!str.empty()) {
+        size_t pos_last = 0;
+        size_t pos_next = 0;
+        while ((pos_next = str.find(",", pos_last)) != std::string::npos) {
+            blobName.push_back(str.substr(pos_last, pos_next - pos_last));
+            pos_last = pos_next + 1;
+        }
+        blobName.push_back(str.substr(pos_last));
+    }
+    return blobName;
+}
+
 bool ParseAndCheckCommandLine(int argc, char *argv[]) {
     // ---------------------------Parsing and validation of input args--------------------------------------
     slog::info << "Parsing input parameters" << slog::endl;
@@ -673,8 +687,30 @@ int main(int argc, char *argv[]) {
             genericPluginConfig.insert(std::begin(gnaPluginConfig), std::end(gnaPluginConfig));
         }
         auto t0 = Time::now();
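+        // Output layer names parsed from -oname; they are added to the network as extra outputs before it is loaded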
+        std::vector<std::string> outputs;
         ExecutableNetwork executableNet;
 
+        if (!FLAGS_oname.empty()) {
+            std::vector<std::string> output_names = ParseBlobName(FLAGS_oname);
+            std::vector<size_t> ports;
+            for (const auto& outBlobName : output_names) {
+                auto pos_layer = outBlobName.rfind(":");
+                if (pos_layer == std::string::npos) {
+                    throw std::logic_error(std::string("Output ") + outBlobName +
+                                           std::string(" doesn't have a port"));
+                }
+                outputs.push_back(outBlobName.substr(0, pos_layer));
+                try {
+                    ports.push_back(std::stoi(outBlobName.substr(pos_layer + 1)));
+                } catch (const std::exception&) {
+                    throw std::logic_error("Ports should have an integer type");
+                }
+            }
+
+            for (size_t i = 0; i < outputs.size(); i++) {
+                network.addOutput(outputs[i], ports[i]);
+            }
+        }
         if (!FLAGS_m.empty()) {
             slog::info << "Loading model to the device" << slog::endl;
             executableNet = ie.LoadNetwork(network, deviceStr, genericPluginConfig);
@@ -682,7 +718,6 @@ int main(int argc, char *argv[]) {
             slog::info << "Importing model to the device" << slog::endl;
             executableNet = ie.ImportNetwork(FLAGS_rg.c_str(), deviceStr, genericPluginConfig);
         }
-
         ms loadTime = std::chrono::duration_cast<ms>(Time::now() - t0);
         slog::info << "Model loading time " << loadTime.count() << " ms" << slog::endl;
 
@@ -717,10 +752,27 @@ int main(int argc, char *argv[]) {
 
         /** Stores all input blobs data **/
         std::vector<Blob::Ptr> ptrInputBlobs;
-        for (auto& input : cInputInfo) {
-            ptrInputBlobs.push_back(inferRequests.begin()->inferRequest.GetBlob(input.first));
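+        // When -iname is given, look the input blobs up by the user-supplied names (in the order given)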
+        if (!FLAGS_iname.empty()) {
+            std::vector<std::string> inputNameBlobs = ParseBlobName(FLAGS_iname);
+            if (inputNameBlobs.size() != cInputInfo.size()) {
+                std::string errMessage(std::string("Number of network inputs ( ") + std::to_string(cInputInfo.size()) +
+                                       " ) is not equal to the number of inputs entered in the -iname argument ( " +
+                                       std::to_string(inputNameBlobs.size()) + " ).");
+                throw std::logic_error(errMessage);
+            }
+            for (const auto& input : inputNameBlobs) {
+                Blob::Ptr blob = inferRequests.begin()->inferRequest.GetBlob(input);
+                if (!blob) {
+                    std::string errMessage("No blob with name : " + input);
+                    throw std::logic_error(errMessage);
+                }
+                ptrInputBlobs.push_back(blob);
+            }
+        } else {
+            for (const auto& input : cInputInfo) {
+                ptrInputBlobs.push_back(inferRequests.begin()->inferRequest.GetBlob(input.first));
+            }
         }
-
         InputsDataMap inputInfo;
         if (!FLAGS_m.empty()) {
             inputInfo = network.getInputsInfo();
@@ -739,8 +791,21 @@ int main(int argc, char *argv[]) {
         if (!FLAGS_m.empty()) {
             outputInfo = network.getOutputsInfo();
         }
-
-        Blob::Ptr ptrOutputBlob = inferRequests.begin()->inferRequest.GetBlob(cOutputInfo.rbegin()->first);
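+        // Collect output blobs: the ones requested via -oname, or every network output otherwise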
+        std::vector<Blob::Ptr> ptrOutputBlob;
+        if (!outputs.empty()) {
+            for (const auto& output : outputs) {
+                Blob::Ptr blob = inferRequests.begin()->inferRequest.GetBlob(output);
+                if (!blob) {
+                    std::string errMessage("No blob with name : " + output);
+                    throw std::logic_error(errMessage);
+                }
+                ptrOutputBlob.push_back(blob);
+            }
+        } else {
+            for (auto& output : cOutputInfo) {
+                ptrOutputBlob.push_back(inferRequests.begin()->inferRequest.GetBlob(output.first));
+            }
+        }
 
         for (auto &item : outputInfo) {
             DataPtr outData = item.second;
@@ -754,255 +819,290 @@ int main(int argc, char *argv[]) {
         // -----------------------------------------------------------------------------------------------------
 
         // --------------------------- 10. Do inference --------------------------------------------------------
-        std::vector<std::vector<uint8_t>> ptrUtterances;
-        std::vector<uint8_t> ptrScores;
-        std::vector<uint8_t> ptrReferenceScores;
-        score_error_t frameError, totalError;
-
-        ptrUtterances.resize(inputArkFiles.size());
-
-        // initialize memory state before starting
-        for (auto &&state : executableNet.QueryState()) {
-            state.Reset();
-        }
-
-        for (uint32_t utteranceIndex = 0; utteranceIndex < numUtterances; ++utteranceIndex) {
-            std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> utterancePerfMap;
-            std::string uttName;
-            uint32_t numFrames(0), n(0);
-            std::vector<uint32_t> numFrameElementsInput;
-
-            uint32_t numFramesReference(0), numFrameElementsReference(0), numBytesPerElementReference(0),
-                    numBytesReferenceScoreThisUtterance(0);
-            const uint32_t numScoresPerFrame = ptrOutputBlob->size() / batchSize;
-
-            numFrameElementsInput.resize(numInputArkFiles);
-            for (size_t i = 0; i < inputArkFiles.size(); i++) {
-                std::vector<uint8_t> ptrUtterance;
-                auto inputArkFilename = inputArkFiles[i].c_str();
-                uint32_t currentNumFrames(0), currentNumFrameElementsInput(0), currentNumBytesPerElementInput(0);
-                GetKaldiArkInfo(inputArkFilename, utteranceIndex, &n, &numBytesThisUtterance[i]);
-                ptrUtterance.resize(numBytesThisUtterance[i]);
-                LoadKaldiArkArray(inputArkFilename,
-                                  utteranceIndex,
-                                  uttName,
-                                  ptrUtterance,
-                                  &currentNumFrames,
-                                  &currentNumFrameElementsInput,
-                                  &currentNumBytesPerElementInput);
-                if (numFrames == 0) {
-                    numFrames = currentNumFrames;
-                } else if (numFrames != currentNumFrames) {
-                    std::string errMessage("Number of frames in ark files is different: " + std::to_string(numFrames) +
-                                           " and " + std::to_string(currentNumFrames));
-                    throw std::logic_error(errMessage);
-                }
-
-                ptrUtterances[i] = ptrUtterance;
-                numFrameElementsInput[i] = currentNumFrameElementsInput;
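+        // -o and -r may now list several files (one per requested output); their count drives the loop below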
+        std::vector<std::string> output_name_files;
+        std::vector<std::string> reference_name_files;
+        size_t count_file = 1;
+        if (!FLAGS_o.empty()) {
+            output_name_files = ParseBlobName(FLAGS_o);
+            if (output_name_files.size() != outputs.size() && !outputs.empty()) {
+                throw std::logic_error("The number of output files is not equal to the number of network outputs.");
             }
-
-            int i = 0;
-            for (auto& ptrInputBlob : ptrInputBlobs) {
-                if (ptrInputBlob->size() != numFrameElementsInput[i++] * batchSize) {
-                    throw std::logic_error("network input size(" + std::to_string(ptrInputBlob->size()) +
-                                           ") mismatch to ark file size (" +
-                                           std::to_string(numFrameElementsInput[i-1] * batchSize) + ")");
-                }
+            count_file = output_name_files.empty() ? 1 : output_name_files.size();
+        }
+        if (!FLAGS_r.empty()) {
+            reference_name_files = ParseBlobName(FLAGS_r);
+            if (reference_name_files.size() != outputs.size() && !outputs.empty()) {
+                throw std::logic_error("The number of reference files is not equal to the number of network outputs.");
             }
+            count_file = reference_name_files.empty() ? 1 : reference_name_files.size();
+        }
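+        // Process the whole utterance set once per requested output / reference file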
+        for (size_t next_output = 0; next_output < count_file; next_output++) {
+            std::vector<std::vector<uint8_t>> ptrUtterances;
+            std::vector<uint8_t> ptrScores;
+            std::vector<uint8_t> ptrReferenceScores;
+            score_error_t frameError, totalError;
+
+            ptrUtterances.resize(inputArkFiles.size());
 
-            ptrScores.resize(numFrames * numScoresPerFrame * sizeof(float));
-            if (!FLAGS_r.empty()) {
-                std::string refUtteranceName;
-                GetKaldiArkInfo(FLAGS_r.c_str(), utteranceIndex, &n, &numBytesReferenceScoreThisUtterance);
-                ptrReferenceScores.resize(numBytesReferenceScoreThisUtterance);
-                LoadKaldiArkArray(FLAGS_r.c_str(),
-                                  utteranceIndex,
-                                  refUtteranceName,
-                                  ptrReferenceScores,
-                                  &numFramesReference,
-                                  &numFrameElementsReference,
-                                  &numBytesPerElementReference);
+            // initialize memory state before starting
+            for (auto &&state : executableNet.QueryState()) {
+                state.Reset();
             }
 
-            double totalTime = 0.0;
+            for (uint32_t utteranceIndex = 0; utteranceIndex < numUtterances; ++utteranceIndex) {
+                std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> utterancePerfMap;
+                std::string uttName;
+                uint32_t numFrames(0), n(0);
+                std::vector<uint32_t> numFrameElementsInput;
+
+                uint32_t numFramesReference(0), numFrameElementsReference(0), numBytesPerElementReference(0),
+                        numBytesReferenceScoreThisUtterance(0);
+                const uint32_t numScoresPerFrame = ptrOutputBlob[next_output]->size() / batchSize;
+
+                numFrameElementsInput.resize(numInputArkFiles);
+                for (size_t i = 0; i < inputArkFiles.size(); i++) {
+                    std::vector<uint8_t> ptrUtterance;
+                    auto inputArkFilename = inputArkFiles[i].c_str();
+                    uint32_t currentNumFrames(0), currentNumFrameElementsInput(0), currentNumBytesPerElementInput(0);
+                    GetKaldiArkInfo(inputArkFilename, utteranceIndex, &n, &numBytesThisUtterance[i]);
+                    ptrUtterance.resize(numBytesThisUtterance[i]);
+                    LoadKaldiArkArray(inputArkFilename,
+                                      utteranceIndex,
+                                      uttName,
+                                      ptrUtterance,
+                                      &currentNumFrames,
+                                      &currentNumFrameElementsInput,
+                                      &currentNumBytesPerElementInput);
+                    if (numFrames == 0) {
+                        numFrames = currentNumFrames;
+                    } else if (numFrames != currentNumFrames) {
+                        std::string errMessage(
+                                "Number of frames in ark files is different: " + std::to_string(numFrames) +
+                                " and " + std::to_string(currentNumFrames));
+                        throw std::logic_error(errMessage);
+                    }
 
-            std::cout << "Utterance " << utteranceIndex << ": " << std::endl;
+                    ptrUtterances[i] = ptrUtterance;
+                    numFrameElementsInput[i] = currentNumFrameElementsInput;
+                }
 
-            ClearScoreError(&totalError);
-            totalError.threshold = frameError.threshold = MAX_SCORE_DIFFERENCE;
-            auto outputFrame = &ptrScores.front();
-            std::vector<uint8_t*> inputFrame;
-            for (auto& ut : ptrUtterances) {
-                inputFrame.push_back(&ut.front());
-            }
+                int i = 0;
+                for (auto &ptrInputBlob : ptrInputBlobs) {
+                    if (ptrInputBlob->size() != numFrameElementsInput[i++] * batchSize) {
+                        throw std::logic_error("network input size(" + std::to_string(ptrInputBlob->size()) +
+                                               ") mismatch to ark file size (" +
+                                               std::to_string(numFrameElementsInput[i - 1] * batchSize) + ")");
+                    }
+                }
 
-            std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> callPerfMap;
+                ptrScores.resize(numFrames * numScoresPerFrame * sizeof(float));
+                if (!FLAGS_r.empty()) {
+                    std::string refUtteranceName;
+                    GetKaldiArkInfo(reference_name_files[next_output].c_str(), utteranceIndex, &n, &numBytesReferenceScoreThisUtterance);
+                    ptrReferenceScores.resize(numBytesReferenceScoreThisUtterance);
+                    LoadKaldiArkArray(reference_name_files[next_output].c_str(),
+                                      utteranceIndex,
+                                      refUtteranceName,
+                                      ptrReferenceScores,
+                                      &numFramesReference,
+                                      &numFrameElementsReference,
+                                      &numBytesPerElementReference);
+                }
 
-            size_t frameIndex = 0;
-            uint32_t numFramesArkFile = numFrames;
-            numFrames += FLAGS_cw_l + FLAGS_cw_r;
-            uint32_t numFramesThisBatch{batchSize};
+                double totalTime = 0.0;
 
-            auto t0 = Time::now();
-            auto t1 = t0;
+                std::cout << "Utterance " << utteranceIndex << ": " << std::endl;
 
-            while (frameIndex <= numFrames) {
-                if (frameIndex == numFrames) {
-                    if (std::find_if(inferRequests.begin(),
-                            inferRequests.end(),
-                            [&](InferRequestStruct x) { return (x.frameIndex != -1); } ) == inferRequests.end()) {
-                        break;
-                    }
+                ClearScoreError(&totalError);
+                totalError.threshold = frameError.threshold = MAX_SCORE_DIFFERENCE;
+                auto outputFrame = &ptrScores.front();
+                std::vector<uint8_t *> inputFrame;
+                for (auto &ut : ptrUtterances) {
+                    inputFrame.push_back(&ut.front());
                 }
 
-                bool inferRequestFetched = false;
-                for (auto &inferRequest : inferRequests) {
+                std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> callPerfMap;
+
+                size_t frameIndex = 0;
+                uint32_t numFramesArkFile = numFrames;
+                numFrames += FLAGS_cw_l + FLAGS_cw_r;
+                uint32_t numFramesThisBatch{batchSize};
+
+                auto t0 = Time::now();
+                auto t1 = t0;
+
+                while (frameIndex <= numFrames) {
                     if (frameIndex == numFrames) {
-                        numFramesThisBatch = 1;
-                    } else {
-                        numFramesThisBatch = (numFrames - frameIndex < batchSize) ? (numFrames - frameIndex)
-                                                                                  : batchSize;
+                        if (std::find_if(inferRequests.begin(),
+                                         inferRequests.end(),
+                                         [&](InferRequestStruct x) { return (x.frameIndex != -1); }) ==
+                            inferRequests.end()) {
+                            break;
+                        }
                     }
 
-                    if (inferRequest.frameIndex != -1) {
-                        StatusCode code = inferRequest.inferRequest.Wait(
-                                InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
-
-                        if (code != StatusCode::OK) {
-                            if (!useHetero) continue;
-                            if (code != StatusCode::INFER_NOT_STARTED) continue;
+                    bool inferRequestFetched = false;
+                    for (auto &inferRequest : inferRequests) {
+                        if (frameIndex == numFrames) {
+                            numFramesThisBatch = 1;
+                        } else {
+                            numFramesThisBatch = (numFrames - frameIndex < batchSize) ? (numFrames - frameIndex)
+                                                                                      : batchSize;
                         }
 
-                        if (inferRequest.frameIndex >= 0) {
-                            if (!FLAGS_o.empty()) {
-                                outputFrame =
-                                        &ptrScores.front() + numScoresPerFrame * sizeof(float) * (inferRequest.frameIndex);
-                                MemoryBlob::CPtr moutput = as<MemoryBlob>(inferRequest.inferRequest.GetBlob(cOutputInfo.rbegin()->first));
-                                if (!moutput) {
-                                    throw std::logic_error("We expect output to be inherited from MemoryBlob, "
-                                                           "but by fact we were not able to cast output to MemoryBlob");
-                                }
-                                // locked memory holder should be alive all time while access to its buffer happens
-                                auto moutputHolder = moutput->rmap();
-                                auto byteSize = inferRequest.numFramesThisBatch * numScoresPerFrame * sizeof(float);
-                                std::memcpy(outputFrame,
-                                            moutputHolder.as<const void *>(),
-                                            byteSize);
-                            }
+                        if (inferRequest.frameIndex != -1) {
+                            StatusCode code = inferRequest.inferRequest.Wait(
+                                    InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
 
-                            if (!FLAGS_r.empty()) {
-                                Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(cOutputInfo.rbegin()->first);
-                                MemoryBlob::CPtr moutput = as<MemoryBlob>(outputBlob);
-                                if (!moutput) {
-                                    throw std::logic_error("We expect output to be inherited from MemoryBlob, "
-                                                           "but by fact we were not able to cast output to MemoryBlob");
-                                }
-                                // locked memory holder should be alive all time while access to its buffer happens
-                                auto moutputHolder = moutput->rmap();
-                                CompareScores(moutputHolder.as<float *>(),
-                                              &ptrReferenceScores[inferRequest.frameIndex *
-                                                                  numFrameElementsReference *
-                                                                  numBytesPerElementReference],
-                                              &frameError,
-                                              inferRequest.numFramesThisBatch,
-                                              numFrameElementsReference);
-                                UpdateScoreError(&frameError, &totalError);
+                            if (code != StatusCode::OK) {
+                                if (!useHetero) continue;
+                                if (code != StatusCode::INFER_NOT_STARTED) continue;
                             }
-                            if (FLAGS_pc) {
-                                // retrieve new counters
-                                getPerformanceCounters(inferRequest.inferRequest, callPerfMap);
-                                // summarize retrieved counters with all previous
-                                sumPerformanceCounters(callPerfMap, utterancePerfMap);
+                            ConstOutputsDataMap newOutputInfo;
+                            if (inferRequest.frameIndex >= 0) {
+                                if (!FLAGS_o.empty()) {
+                                    outputFrame =
+                                            &ptrScores.front() +
+                                            numScoresPerFrame * sizeof(float) * (inferRequest.frameIndex);
+                                    if (!outputs.empty()) {
+                                        newOutputInfo[outputs[next_output]] = cOutputInfo[outputs[next_output]];
+                                    } else {
+                                        newOutputInfo = cOutputInfo;
+                                    }
+                                    Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(newOutputInfo.rbegin()->first);
+                                    MemoryBlob::CPtr moutput = as<MemoryBlob>(outputBlob);
+
+                                    if (!moutput) {
+                                        throw std::logic_error("We expect output to be inherited from MemoryBlob, "
+                                                               "but in fact we were not able to cast output to MemoryBlob");
+                                    }
+                                    // locked memory holder should be alive all time while access to its buffer happens
+                                    auto moutputHolder = moutput->rmap();
+                                    auto byteSize =
+                                            inferRequest.numFramesThisBatch * numScoresPerFrame * sizeof(float);
+                                    std::memcpy(outputFrame,
+                                                moutputHolder.as<const void *>(),
+                                                byteSize);
+                                }
+                                if (!FLAGS_r.empty()) {
+                                    if (!outputs.empty()) {
+                                        newOutputInfo[outputs[next_output]] = cOutputInfo[outputs[next_output]];
+                                    } else {
+                                        newOutputInfo = cOutputInfo;
+                                    }
+                                    Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(newOutputInfo.rbegin()->first);
+                                    MemoryBlob::CPtr moutput = as<MemoryBlob>(outputBlob);
+                                    if (!moutput) {
+                                        throw std::logic_error("We expect output to be inherited from MemoryBlob, "
+                                                               "but in fact we were not able to cast output to MemoryBlob");
+                                    }
+                                    // locked memory holder should be alive all time while access to its buffer happens
+                                    auto moutputHolder = moutput->rmap();
+                                    CompareScores(moutputHolder.as<float *>(),
+                                                  &ptrReferenceScores[inferRequest.frameIndex *
+                                                                      numFrameElementsReference *
+                                                                      numBytesPerElementReference],
+                                                  &frameError,
+                                                  inferRequest.numFramesThisBatch,
+                                                  numFrameElementsReference);
+                                    UpdateScoreError(&frameError, &totalError);
+                                }
+                                if (FLAGS_pc) {
+                                    // retrieve new counters
+                                    getPerformanceCounters(inferRequest.inferRequest, callPerfMap);
+                                    // summarize retrieved counters with all previous
+                                    sumPerformanceCounters(callPerfMap, utterancePerfMap);
+                                }
                             }
                         }
-                    }
-
-                    if (frameIndex == numFrames) {
-                        inferRequest.frameIndex = -1;
-                        continue;
-                    }
-
-                    ptrInputBlobs.clear();
-                    for (auto& input : cInputInfo) {
-                        ptrInputBlobs.push_back(inferRequest.inferRequest.GetBlob(input.first));
-                    }
 
-                    for (size_t i = 0; i < numInputArkFiles; ++i) {
-                        MemoryBlob::Ptr minput = as<MemoryBlob>(ptrInputBlobs[i]);
-                        if (!minput) {
-                            slog::err << "We expect ptrInputBlobs[" << i << "] to be inherited from MemoryBlob, " <<
-                                "but by fact we were not able to cast input blob to MemoryBlob" << slog::endl;
-                            return 1;
+                        if (frameIndex == numFrames) {
+                            inferRequest.frameIndex = -1;
+                            continue;
                         }
-                        // locked memory holder should be alive all time while access to its buffer happens
-                        auto minputHolder = minput->wmap();
 
-                        std::memcpy(minputHolder.as<void*>(),
-                                    inputFrame[i],
-                                    minput  ->byteSize());
-                    }
+                        if (FLAGS_iname.empty()) {
+                            size_t num_files = FLAGS_iname.empty() ? numInputArkFiles : ptrInputBlobs.size();
+                            for (size_t i = 0; i < num_files; ++i) {
+                                MemoryBlob::Ptr minput = as<MemoryBlob>(ptrInputBlobs[i]);
+                                if (!minput) {
+                                    slog::err << "We expect ptrInputBlobs[" << i
+                                              << "] to be inherited from MemoryBlob, " <<
+                                              "but in fact we were not able to cast input blob to MemoryBlob"
+                                              << slog::endl;
+                                    return 1;
+                                }
+                                // locked memory holder should be alive all time while access to its buffer happens
+                                auto minputHolder = minput->wmap();
 
-                    int index = static_cast<int>(frameIndex) - (FLAGS_cw_l + FLAGS_cw_r);
-                    inferRequest.inferRequest.StartAsync();
-                    inferRequest.frameIndex = index < 0 ? -2 : index;
-                    inferRequest.numFramesThisBatch = numFramesThisBatch;
+                                std::memcpy(minputHolder.as<void *>(),
+                                            inputFrame[i],
+                                            minput->byteSize());
+                            }
+                        }
 
-                    frameIndex += numFramesThisBatch;
-                    for (size_t j = 0; j < inputArkFiles.size(); j++) {
-                        if (FLAGS_cw_l > 0 || FLAGS_cw_r > 0) {
-                            int idx = frameIndex - FLAGS_cw_l;
-                            if (idx > 0 && idx < static_cast<int>(numFramesArkFile)) {
+                        int index = static_cast<int>(frameIndex) - (FLAGS_cw_l + FLAGS_cw_r);
+                        inferRequest.inferRequest.StartAsync();
+                        inferRequest.frameIndex = index < 0 ? -2 : index;
+                        inferRequest.numFramesThisBatch = numFramesThisBatch;
+
+                        frameIndex += numFramesThisBatch;
+                        for (size_t j = 0; j < inputArkFiles.size(); j++) {
+                            if (FLAGS_cw_l > 0 || FLAGS_cw_r > 0) {
+                                int idx = frameIndex - FLAGS_cw_l;
+                                if (idx > 0 && idx < static_cast<int>(numFramesArkFile)) {
+                                    inputFrame[j] += sizeof(float) * numFrameElementsInput[j] * numFramesThisBatch;
+                                } else if (idx >= static_cast<int>(numFramesArkFile)) {
+                                    inputFrame[j] = &ptrUtterances[j].front() +
+                                                    (numFramesArkFile - 1) * sizeof(float) * numFrameElementsInput[j] *
+                                                    numFramesThisBatch;
+                                } else if (idx <= 0) {
+                                    inputFrame[j] = &ptrUtterances[j].front();
+                                }
+                            } else {
                                 inputFrame[j] += sizeof(float) * numFrameElementsInput[j] * numFramesThisBatch;
-                            } else if (idx >= static_cast<int>(numFramesArkFile)) {
-                                inputFrame[j] = &ptrUtterances[j].front() +
-                                        (numFramesArkFile - 1) * sizeof(float) * numFrameElementsInput[j] * numFramesThisBatch;
-                            } else if (idx <= 0) {
-                                inputFrame[j] = &ptrUtterances[j].front();
                             }
-                        } else {
-                            inputFrame[j] += sizeof(float) * numFrameElementsInput[j] * numFramesThisBatch;
                         }
+                        inferRequestFetched |= true;
+                    }
+                    if (!inferRequestFetched) {
+                        std::this_thread::sleep_for(std::chrono::milliseconds(1));
+                        continue;
                     }
-                    inferRequestFetched |= true;
-                }
-                if (!inferRequestFetched) {
-                    std::this_thread::sleep_for(std::chrono::milliseconds(1));
-                    continue;
                 }
-            }
-            t1 = Time::now();
+                t1 = Time::now();
 
-            fsec fs = t1 - t0;
-            ms d = std::chrono::duration_cast<ms>(fs);
-            totalTime += d.count();
+                fsec fs = t1 - t0;
+                ms d = std::chrono::duration_cast<ms>(fs);
+                totalTime += d.count();
 
-            // resetting state between utterances
-            for (auto &&state : executableNet.QueryState()) {
-                state.Reset();
-            }
+                // resetting state between utterances
+                for (auto &&state : executableNet.QueryState()) {
+                    state.Reset();
+                }
 
-            if (!FLAGS_o.empty()) {
-                bool shouldAppend = (utteranceIndex == 0) ? false : true;
-                SaveKaldiArkArray(FLAGS_o.c_str(), shouldAppend, uttName, &ptrScores.front(),
-                                  numFramesArkFile, numScoresPerFrame);
-            }
+                if (!FLAGS_o.empty()) {
+                    bool shouldAppend = (utteranceIndex != 0);
+                    SaveKaldiArkArray(output_name_files[next_output].c_str(), shouldAppend, uttName, &ptrScores.front(),
+                                      numFramesArkFile, numScoresPerFrame);
+                }
 
-            /** Show performance results **/
-            std::cout << "Total time in Infer (HW and SW):\t" << totalTime << " ms"
-                      << std::endl;
-            std::cout << "Frames in utterance:\t\t\t" << numFrames << " frames"
-                      << std::endl;
-            std::cout << "Average Infer time per frame:\t\t" << totalTime / static_cast<double>(numFrames) << " ms"
-                      << std::endl;
-            if (FLAGS_pc) {
-                // print
-                printPerformanceCounters(utterancePerfMap, frameIndex, std::cout, getFullDeviceName(ie, FLAGS_d));
-            }
-            if (!FLAGS_r.empty()) {
-                printReferenceCompareResults(totalError, numFrames, std::cout);
+                /** Show performance results **/
+                std::cout << "Total time in Infer (HW and SW):\t" << totalTime << " ms"
+                          << std::endl;
+                std::cout << "Frames in utterance:\t\t\t" << numFrames << " frames"
+                          << std::endl;
+                std::cout << "Average Infer time per frame:\t\t" << totalTime / static_cast<double>(numFrames) << " ms"
+                          << std::endl;
+                if (FLAGS_pc) {
+                    // print
+                    printPerformanceCounters(utterancePerfMap, frameIndex, std::cout, getFullDeviceName(ie, FLAGS_d));
+                }
+                if (!FLAGS_r.empty()) {
+                    printReferenceCompareResults(totalError, numFrames, std::cout);
+                }
+                std::cout << "End of Utterance " << utteranceIndex << std::endl << std::endl;
             }
-            std::cout << "End of Utterance " << utteranceIndex << std::endl << std::endl;
         }
         // -----------------------------------------------------------------------------------------------------
     }
index a76a51c..f608158 100644 (file)
@@ -81,6 +81,16 @@ static const char context_window_message_r[] = "Optional. Number of frames for r
                                                "Works only with context window networks."
                                                " If you use the cw_r or cw_l flag, then batch size and nthreads arguments are ignored.";
 
+/// @brief message for output layer names
+static const char output_layer_names_message[] = "Optional. Layer names for output blobs. " \
+                                          "The names are separated with \",\". " \
+                                          "Example: Output1:port,Output2:port ";
+
+/// @brief message for inputs layer names
+static const char input_layer_names_message[] = "Optional. Layer names for input blobs. " \
+                                          "The names are separated with \",\". " \
+                                          "Example: Input1,Input2 ";
+
 /// \brief Define flag for showing help message <br>
 DEFINE_bool(h, false, help_message);
 
@@ -145,6 +155,12 @@ DEFINE_int32(cw_r, 0, context_window_message_r);
 /// @brief Left context window size (default 0)
 DEFINE_int32(cw_l, 0, context_window_message_l);
 
+/// @brief Output layer name
+DEFINE_string(oname, "", output_layer_names_message);
+
+/// @brief Input layer name
+DEFINE_string(iname, "", input_layer_names_message);
+
 /**
  * \brief This function show a help message
  */
@@ -173,5 +189,7 @@ static void showUsage() {
     std::cout << "    -nthreads \"<integer>\"   " << infer_num_threads_message << std::endl;
     std::cout << "    -cw_l \"<integer>\"       " << context_window_message_l << std::endl;
     std::cout << "    -cw_r \"<integer>\"       " << context_window_message_r << std::endl;
+    std::cout << "    -oname \"<string>\"       " << output_layer_names_message << std::endl;
+    std::cout << "    -iname \"<string>\"       " << input_layer_names_message << std::endl;
 }