Publishing 2019 R1 content
[platform/upstream/dldt.git]
diff --git a/inference-engine/samples/speech_sample/main.cpp b/inference-engine/samples/speech_sample/main.cpp
index e0dc005..4b7115a 100644
--- a/inference-engine/samples/speech_sample/main.cpp
+++ b/inference-engine/samples/speech_sample/main.cpp
@@ -1,4 +1,4 @@
-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
@@ -25,6 +25,7 @@
 #include <samples/common.hpp>
 #include <samples/slog.hpp>
 #include <samples/args_helper.hpp>
+#include <ext_list.hpp>
 
 #ifndef ALIGN
 #define ALIGN(memSize, pad)   ((static_cast<int>((memSize) + pad - 1) / pad) * pad)
@@ -51,6 +52,12 @@ typedef struct {
     float sumSquaredRelError;
 } score_error_t;
 
+struct InferRequestStruct {
+    InferRequest inferRequest;
+    int frameIndex;
+    uint32_t numFramesThisBatch;
+};
+
 void GetKaldiArkInfo(const char *fileName,
                      uint32_t numArrayToFindSize,
                      uint32_t *ptrNumArrays,
@@ -119,7 +126,6 @@ void LoadKaldiArkArray(const char *fileName, uint32_t arrayIndex, std::string &p
             in_file.read(reinterpret_cast<char *>(ptrNumRows), sizeof(uint32_t));        // read number of rows
             std::getline(in_file, line, '\4');                                            // read control-D
             in_file.read(reinterpret_cast<char *>(ptrNumColumns), sizeof(uint32_t));    // read number of columns
-            size_t willWrite = *ptrNumRows * *ptrNumColumns * sizeof(float);
             in_file.read(reinterpret_cast<char *>(&memory.front()),
                          *ptrNumRows * *ptrNumColumns * sizeof(float));  // read array data
         }
@@ -286,7 +292,6 @@ inline void native_cpuid(unsigned int *eax, unsigned int *ebx,
 
 // return GNA module frequency in MHz
 float getGnaFrequencyMHz() {
-    uint32_t level = 0;
     uint32_t eax = 1;
     uint32_t ebx = 0;
     uint32_t ecx = 0;
@@ -353,12 +358,11 @@ void printPerformanceCounters(std::map<std::string,
 
     for (const auto &it : utterancePerfMap) {
         std::string const &counter_name = it.first;
-        float current_units = it.second.realTime_uSec;
+        float current_units = static_cast<float>(it.second.realTime_uSec);
         float call_units = current_units / callsNum;
-        float freq = 1.0;
         // if GNA HW counters
         // get frequency of GNA module
-        freq = getGnaFrequencyMHz();
+        float freq = getGnaFrequencyMHz();
         current_units /= freq * 1000;
         call_units /= freq;
         stream << std::setw(30) << std::left << counter_name.substr(4, counter_name.size() - 1);
@@ -414,9 +418,20 @@ bool ParseAndCheckCommandLine(int argc, char *argv[]) {
         throw std::logic_error("Only one of -m and -rg is allowed.");
     }
 
-    if ((FLAGS_d.compare("GPU") != 0) && (FLAGS_d.compare("CPU") != 0) && (FLAGS_d.compare("GNA_AUTO") != 0) &&
-        (FLAGS_d.compare("GNA_HW") != 0)
-        && (FLAGS_d.compare("GNA_SW") != 0) && (FLAGS_d.compare("GNA_SW_EXACT") != 0)) {
+    std::vector<std::string> possibleDeviceTypes = {
+            "CPU",
+            "GPU",
+            "GNA_AUTO",
+            "GNA_HW",
+            "GNA_SW_EXACT",
+            "GNA_SW",
+            "HETERO:GNA,CPU",
+            "HETERO:GNA_HW,CPU",
+            "HETERO:GNA_SW_EXACT,CPU",
+            "HETERO:GNA_SW,CPU",
+    };
+
+    if (std::find(possibleDeviceTypes.begin(), possibleDeviceTypes.end(), FLAGS_d) == possibleDeviceTypes.end()) {
         throw std::logic_error("Specified device is not supported.");
     }
 
@@ -447,6 +462,10 @@ bool ParseAndCheckCommandLine(int argc, char *argv[]) {
         throw std::logic_error("Not valid value for 'nthreads' argument. It should be > 0 ");
     }
 
+    if (FLAGS_cw < 0) {
+        throw std::logic_error("Not valid value for 'cw' argument. It should be >= 0 ");
+    }
+
     return true;
 }
 
@@ -468,10 +487,14 @@ int main(int argc, char *argv[]) {
             slog::info << "No extensions provided" << slog::endl;
         }
 
-        bool useGna = (FLAGS_d.find("GNA") != std::string::npos);
-        auto deviceStr = FLAGS_d.substr(0, (FLAGS_d.find("_")));
+        auto isFeature = [&](const std::string xFeature) { return FLAGS_d.find(xFeature) != std::string::npos; };
+
+        bool useGna = isFeature("GNA");
+        bool useHetero = isFeature("HETERO");
+        std::string deviceStr =
+                useHetero && useGna ? "HETERO:GNA,CPU" : FLAGS_d.substr(0, (FLAGS_d.find("_")));
         float scaleFactorInput = static_cast<float>(FLAGS_sf);
-        uint32_t batchSize = (uint32_t) FLAGS_bs;
+        uint32_t batchSize = FLAGS_cw > 0 ? 1 : (uint32_t) FLAGS_bs;
         /** Extract input ark file name **/
         std::string inputArkName = fileNameNoExt(FLAGS_i) + ".ark";
 
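A minimal standalone sketch of the device-string handling above, assuming the hypothetical
stand-ins d, cw and bs for FLAGS_d, FLAGS_cw and FLAGS_bs: any HETERO GNA variant is loaded
as the "HETERO:GNA,CPU" plugin (the concrete GNA mode is applied later through the plugin
config), and a non-zero context window forces a batch size of 1.

    // Sketch only; d, cw and bs stand in for the gflags values used by the sample.
    #include <cstdint>
    #include <iostream>
    #include <string>

    int main() {
        const std::string d = "HETERO:GNA_SW_EXACT,CPU";  // assumed -d value
        const int cw = 17;                                // assumed -cw value
        const uint32_t bs = 8;                            // assumed -bs value

        auto isFeature = [&](const std::string &f) { return d.find(f) != std::string::npos; };
        bool useGna = isFeature("GNA");
        bool useHetero = isFeature("HETERO");

        // Any HETERO:GNA_* variant addresses the plugin as "HETERO:GNA,CPU";
        // otherwise the part before '_' selects the plugin (e.g. "GNA_HW" -> "GNA").
        std::string deviceStr = useHetero && useGna ? "HETERO:GNA,CPU"
                                                    : d.substr(0, d.find("_"));

        // A context window forces single-frame batches.
        uint32_t batchSize = cw > 0 ? 1 : bs;

        std::cout << deviceStr << " batch=" << batchSize << std::endl;  // HETERO:GNA,CPU batch=1
        return 0;
    }
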
@@ -484,7 +507,7 @@ int main(int argc, char *argv[]) {
         // --------------------------- 1. Load Plugin for inference engine -------------------------------------
         slog::info << "Loading plugin" << slog::endl;
         /** Loading plugin for device **/
-        InferencePlugin plugin = PluginDispatcher({FLAGS_pp, "../../../lib/intel64", ""}).getPluginByDevice(deviceStr);
+        InferencePlugin plugin = PluginDispatcher({FLAGS_pp}).getPluginByDevice(deviceStr);
 
         /** Printing plugin version **/
         std::cout << plugin.GetVersion() << std::endl << std::endl;
@@ -514,9 +537,20 @@ int main(int argc, char *argv[]) {
         /** Setting plugin parameter for per layer metrics **/
         std::map<std::string, std::string> gnaPluginConfig;
         std::map<std::string, std::string> genericPluginConfig;
-        if (FLAGS_d.compare("CPU") != 0) {
-            gnaPluginConfig[GNAConfigParams::KEY_GNA_DEVICE_MODE] = FLAGS_d;
+        if (useGna) {
+            std::string gnaDevice =
+                    useHetero ? FLAGS_d.substr(FLAGS_d.find("GNA"), FLAGS_d.find(",") - FLAGS_d.find("GNA")) : FLAGS_d;
+            gnaPluginConfig[GNAConfigParams::KEY_GNA_DEVICE_MODE] =
+                    gnaDevice.find("_") == std::string::npos ? "GNA_AUTO" : gnaDevice;
+        } else if (plugin.GetVersion()->description == std::string("MKLDNNPlugin")) {
+            /**
+             * cpu_extensions library is compiled from "extension" folder containing
+             * custom MKLDNNPlugin layer implementations. These layers are not supported
+             * by mkldnn, but they can be useful for inferring custom topologies.
+            **/
+            plugin.AddExtension(std::make_shared<Extensions::Cpu::CpuExtensions>());
         }
+
         if (FLAGS_pc) {
             genericPluginConfig[PluginConfigParams::KEY_PERF_COUNT] = PluginConfigParams::YES;
         }
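The KEY_GNA_DEVICE_MODE derivation above can be exercised in isolation; the helper below is
only an illustration (gnaDeviceMode and the sample strings are not part of the sample or the
plugin API): for a HETERO flag the GNA_* token is cut out of the device string, and a bare
"GNA" falls back to "GNA_AUTO".

    // Illustrative sketch of the mode extraction, not the plugin interface.
    #include <iostream>
    #include <string>

    // Value that would be placed into GNAConfigParams::KEY_GNA_DEVICE_MODE.
    std::string gnaDeviceMode(const std::string &d, bool useHetero) {
        // "HETERO:GNA_SW_EXACT,CPU" -> "GNA_SW_EXACT"; plain -d values are used as-is.
        std::string gnaDevice =
                useHetero ? d.substr(d.find("GNA"), d.find(",") - d.find("GNA")) : d;
        // A bare "GNA" (no '_' suffix) means automatic mode selection.
        return gnaDevice.find("_") == std::string::npos ? "GNA_AUTO" : gnaDevice;
    }

    int main() {
        std::cout << gnaDeviceMode("HETERO:GNA_SW_EXACT,CPU", true) << std::endl;  // GNA_SW_EXACT
        std::cout << gnaDeviceMode("HETERO:GNA,CPU", true) << std::endl;           // GNA_AUTO
        std::cout << gnaDeviceMode("GNA_HW", false) << std::endl;                  // GNA_HW
        return 0;
    }
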
@@ -550,7 +584,7 @@ int main(int argc, char *argv[]) {
             gnaPluginConfig[GNAConfigParams::KEY_GNA_PRECISION] = "I16";
         }
 
-        gnaPluginConfig[GNAConfigParams::KEY_GNA_LIB_N_THREADS] = std::to_string(FLAGS_nthreads);
+        gnaPluginConfig[GNAConfigParams::KEY_GNA_LIB_N_THREADS] = std::to_string(FLAGS_cw > 0 ? 1 : FLAGS_nthreads);
         gnaPluginConfig[GNA_CONFIG_KEY(COMPACT_MODE)] = CONFIG_VALUE(NO);
         // -----------------------------------------------------------------------------------------------------
 
@@ -568,6 +602,7 @@ int main(int argc, char *argv[]) {
         }
         auto t0 = Time::now();
         ExecutableNetwork executableNet;
+
         if (!FLAGS_m.empty()) {
             slog::info << "Loading model to the plugin" << slog::endl;
             executableNet = plugin.LoadNetwork(netBuilder.getNetwork(), genericPluginConfig);
@@ -576,7 +611,6 @@ int main(int argc, char *argv[]) {
             executableNet = plugin.ImportNetwork(FLAGS_rg.c_str(), genericPluginConfig);
         }
 
-
         ms loadTime = std::chrono::duration_cast<ms>(Time::now() - t0);
         slog::info << "Model loading time " << loadTime.count() << " ms" << slog::endl;
 
@@ -595,9 +629,9 @@ int main(int argc, char *argv[]) {
             return 0;
         }
 
-        std::vector<std::pair<InferRequest, size_t>> inferRequests(FLAGS_nthreads);
+        std::vector<InferRequestStruct> inferRequests(FLAGS_cw > 0 ? 1 : FLAGS_nthreads);
         for (auto& inferRequest : inferRequests) {
-            inferRequest = {executableNet.CreateInferRequest(), -1};
+            inferRequest = {executableNet.CreateInferRequest(), -1, batchSize};
         }
         // -----------------------------------------------------------------------------------------------------
 
@@ -614,7 +648,7 @@ int main(int argc, char *argv[]) {
             throw std::logic_error("Sample supports only topologies with  1 input");
         }
 
-        Blob::Ptr ptrInputBlob = inferRequests[0].first.GetBlob(cInputInfo.begin()->first);
+        Blob::Ptr ptrInputBlob = inferRequests[0].inferRequest.GetBlob(cInputInfo.begin()->first);
 
         /** configure input precision if model loaded from IR **/
         for (auto &item : inputInfo) {
@@ -632,7 +666,7 @@ int main(int argc, char *argv[]) {
             outputInfo = netBuilder.getNetwork().getOutputsInfo();
         }
 
-        Blob::Ptr ptrOutputBlob = inferRequests[0].first.GetBlob(cOutputInfo.begin()->first);
+        Blob::Ptr ptrOutputBlob = inferRequests[0].inferRequest.GetBlob(cOutputInfo.begin()->first);
 
         for (auto &item : outputInfo) {
             DataPtr outData = item.second;
@@ -699,22 +733,20 @@ int main(int argc, char *argv[]) {
             auto inputFrame = &ptrUtterance.front();
             auto outputFrame = &ptrScores.front();
 
-            size_t frameIndex{0};
+            std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> callPerfMap;
+
+            size_t frameIndex = 0;
+            numFrames += 2 * FLAGS_cw;
             uint32_t numFramesThisBatch{batchSize};
 
             auto t0 = Time::now();
             auto t1 = t0;
 
-            // Doing inference
             while (frameIndex <= numFrames) {
                 if (frameIndex == numFrames) {
-                    bool hasRequests = false;
-                    for (auto &inferRequest : inferRequests) {
-                        if (inferRequest.second != -1) {
-                            hasRequests = true;
-                        }
-                    }
-                    if (!hasRequests) {
+                    if (std::find_if(inferRequests.begin(),
+                            inferRequests.end(),
+                            [&](InferRequestStruct x) { return (x.frameIndex != -1); } ) == inferRequests.end()) {
                         break;
                     }
                 }
@@ -724,54 +756,79 @@ int main(int argc, char *argv[]) {
                     if (frameIndex == numFrames) {
                         numFramesThisBatch = 1;
                     } else {
-                        numFramesThisBatch = (numFrames - frameIndex < batchSize) ? (numFrames - frameIndex) : batchSize;
+                        numFramesThisBatch = (numFrames - frameIndex < batchSize) ? (numFrames - frameIndex)
+                                                                                  : batchSize;
                     }
 
-                    if (inferRequest.second != -1) {
-                        StatusCode code = inferRequest.first.Wait(
+                    if (inferRequest.frameIndex != -1) {
+                        StatusCode code = inferRequest.inferRequest.Wait(
                                 InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
-
                         if (code != StatusCode::OK) {
-                            continue;
+                            if (!useHetero) continue;
+                            if (code != StatusCode::INFER_NOT_STARTED) continue;
                         }
 
-                        if (!FLAGS_o.empty()) {
-                            Blob::Ptr outputBlob = inferRequest.first.GetBlob(cOutputInfo.begin()->first);
-                            std::memcpy(outputFrame,
-                                        outputBlob->buffer(),
-                                        outputBlob->byteSize());
-                            outputFrame += numScoresPerFrame * sizeof(float);
-                        }
-
-                        if (!FLAGS_r.empty()) {
-                            Blob::Ptr outputBlob = inferRequest.first.GetBlob(cOutputInfo.begin()->first);
-                            CompareScores(outputBlob->buffer().as<float *>(),
-                                          &ptrReferenceScores[inferRequest.second *
-                                                              numFrameElementsReference *
-                                                              numBytesPerElementReference],
-                                          &frameError,
-                                          numFramesThisBatch,
-                                          numFrameElementsReference);
-                            UpdateScoreError(&frameError, &totalError);
+                        if (inferRequest.frameIndex >= 0) {
+                            if (!FLAGS_o.empty()) {
+                                outputFrame =
+                                        &ptrScores.front() + numScoresPerFrame * sizeof(float) * (inferRequest.frameIndex);
+                                Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(cOutputInfo.begin()->first);
+                                auto byteSize = inferRequest.numFramesThisBatch * numScoresPerFrame * sizeof(float);
+                                std::memcpy(outputFrame,
+                                            outputBlob->buffer(),
+                                            byteSize);
+                            }
+
+                            if (!FLAGS_r.empty()) {
+                                Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(cOutputInfo.begin()->first);
+                                CompareScores(outputBlob->buffer().as<float *>(),
+                                              &ptrReferenceScores[inferRequest.frameIndex *
+                                                                  numFrameElementsReference *
+                                                                  numBytesPerElementReference],
+                                              &frameError,
+                                              inferRequest.numFramesThisBatch,
+                                              numFrameElementsReference);
+                                UpdateScoreError(&frameError, &totalError);
+                            }
+                            if (FLAGS_pc) {
+                                // retrieve new counters
+                                getPerformanceCounters(inferRequest.inferRequest, callPerfMap);
+                                // summarize retrieved counters with all previous
+                                sumPerformanceCounters(callPerfMap, utterancePerfMap);
+                            }
                         }
                     }
 
-                    inferRequest.second = -1;
-
                     if (frameIndex == numFrames) {
+                        inferRequest.frameIndex = -1;
                         continue;
                     }
 
-                    Blob::Ptr inputBlob = inferRequest.first.GetBlob(cInputInfo.begin()->first);
+                    Blob::Ptr inputBlob = inferRequest.inferRequest.GetBlob(cInputInfo.begin()->first);
+
                     std::memcpy(inputBlob->buffer(),
                                 inputFrame,
                                 inputBlob->byteSize());
 
-                    inferRequest.first.StartAsync();
+                    auto index = frameIndex - 2 * FLAGS_cw;
+                    inferRequest.inferRequest.StartAsync();
+                    inferRequest.frameIndex = index < 0 ? -2 : index;
+                    inferRequest.numFramesThisBatch = numFramesThisBatch;
 
-                    inferRequest.second = frameIndex;
                     frameIndex += numFramesThisBatch;
-                    inputFrame += sizeof(float) * numFrameElementsInput * numFramesThisBatch;
+
+                    if (FLAGS_cw > 0) {
+                        int i = frameIndex - FLAGS_cw;
+                        if (i > 0 && i < static_cast<int>(numFrames)) {
+                            inputFrame += sizeof(float) * numFrameElementsInput * numFramesThisBatch;
+                        } else if (i >= static_cast<int>(numFrames)) {
+                            inputFrame = &ptrUtterance.front() +
+                                         (numFrames - 1) * sizeof(float) * numFrameElementsInput *
+                                         numFramesThisBatch;
+                        }
+                    } else {
+                        inputFrame += sizeof(float) * numFrameElementsInput * numFramesThisBatch;
+                    }
                     inferRequestFetched |= true;
                 }
 
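A simplified, standalone model of the context-window (-cw) bookkeeping above; this is a
sketch of the idea, not the sample's exact code. With a context window the loop runs 2*cw
extra iterations, the frame copied into the input blob lags the iteration counter by cw
frames (clamped to the first and last frame of the utterance), and the request started at
iteration k is tagged with output frame k - 2*cw, negative values marking warm-up requests
whose scores are discarded. By the time frame m is scored, inputs up to roughly frame m + cw
have therefore been pushed through the network.

    // Sketch only: prints which frame is fed and which frame is scored per iteration.
    #include <algorithm>
    #include <cstdio>

    int main() {
        const int cw = 2;         // hypothetical -cw value
        const int numFrames = 5;  // real frames in the utterance

        for (int k = 0; k < numFrames + 2 * cw; ++k) {
            int fed = std::min(std::max(k - cw, 0), numFrames - 1);  // frame copied into the input blob
            int scored = k - 2 * cw;                                 // frame the result is attributed to
            std::printf("iteration %d: feed frame %d, score frame %d%s\n",
                        k, fed, scored, scored < 0 ? " (discarded)" : "");
        }
        return 0;
    }
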
@@ -779,16 +836,6 @@ int main(int argc, char *argv[]) {
                     std::this_thread::sleep_for(std::chrono::milliseconds(1));
                     continue;
                 }
-
-                if (FLAGS_pc) {
-                    std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> callPerfMap;
-                    // retrive new counters
-                    for (auto inferRequest : inferRequests) {
-                        getPerformanceCounters(inferRequest.first, callPerfMap);
-                        // summarize retrived counters with all previous
-                        sumPerformanceCounters(callPerfMap, utterancePerfMap);
-                    }
-                }
             }
             t1 = Time::now();