Return REQUEST_NOT_READY if GNA Wait returns busy status (#2401)
authorKrzysztof Bruniecki <krzysztof.bruniecki@intel.com>
Mon, 28 Sep 2020 07:22:23 +0000 (09:22 +0200)
committerGitHub <noreply@github.com>
Mon, 28 Sep 2020 07:22:23 +0000 (10:22 +0300)
* Return REQUEST_NOT_READY if GNA Wait returns busy status

* Apply fixes from review

inference-engine/src/gna_plugin/gna_device.cpp
inference-engine/src/gna_plugin/gna_device.hpp
inference-engine/src/gna_plugin/gna_infer_request.hpp
inference-engine/src/gna_plugin/gna_plugin.cpp
inference-engine/src/gna_plugin/gna_plugin.hpp

index 52fe936a097b44cab11a36cb8386b3fa964f4202..97d4026066a1b3e5aceb2a1cdc8beeff032e8127 100644 (file)
@@ -272,11 +272,14 @@ const std::map <const std::pair<Gna2OperationType, int32_t>, const std::string>
 };
 #endif
 
-bool GNADeviceHelper::wait(uint32_t reqId, int64_t millisTimeout) {
+GnaWaitStatus GNADeviceHelper::wait(uint32_t reqId, int64_t millisTimeout) {
 #if GNA_LIB_VER == 2
     const auto status = Gna2RequestWait(reqId, millisTimeout);
     if (status == Gna2StatusDriverQoSTimeoutExceeded) {
-        return false;
+        return GNA_REQUEST_ABORTED;
+    }
+    if (status == Gna2StatusWarningDeviceBusy) {
+        return GNA_REQUEST_PENDING;
     }
     checkGna2Status(status);
     unwaitedRequestIds.erase(std::remove(unwaitedRequestIds.begin(), unwaitedRequestIds.end(), reqId));
@@ -289,7 +292,7 @@ bool GNADeviceHelper::wait(uint32_t reqId, int64_t millisTimeout) {
     checkStatus();
 #endif
     updateGnaPerfCounters();
-    return true;
+    return GNA_REQUEST_COMPLETED;
 }
 
 #if GNA_LIB_VER == 1
index 1a4292253b6c85d3cfcf5a505533d29925c21b35..7b35f3c4a64cfbd4ae850c90015e19e41641ec73 100644 (file)
 #include "gna-api-instrumentation.h"
 #endif
 
+enum GnaWaitStatus : int {
+    GNA_REQUEST_COMPLETED = 0,  // and removed from GNA library queue
+    GNA_REQUEST_ABORTED = 1,    // for QoS purposes
+    GNA_REQUEST_PENDING = 2     // for device busy purposes
+};
 
 /**
  * holds gna - style handle in RAII way
@@ -115,7 +120,7 @@ public:
     static void checkGna2Status(Gna2Status status);
     static void checkGna2Status(Gna2Status status, const Gna2Model& gnaModel);
 #endif
-    bool wait(uint32_t id, int64_t millisTimeout = MAX_TIMEOUT);
+    GnaWaitStatus wait(uint32_t id, int64_t millisTimeout = MAX_TIMEOUT);
 
     struct DumpResult {
 #if GNA_LIB_VER == 2
index 5d34db4bea714a3dccb30c64f6f01d96eae3d189..defbc16cdc736e110f472799e9d753156a4a6248 100644 (file)
@@ -95,20 +95,21 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
             THROW_IE_EXCEPTION << PARAMETER_MISMATCH_str;
         }
 
-        bool qosOK;
         if (millis_timeout == InferenceEngine::IInferRequest::WaitMode::RESULT_READY) {
-            qosOK = plg->Wait(inferRequestIdx);
-        } else {
-            qosOK = plg->WaitFor(inferRequestIdx, millis_timeout);
+            millis_timeout = MAX_TIMEOUT;
         }
+        const auto waitStatus = plg->WaitFor(inferRequestIdx, millis_timeout);
 
-        if (qosOK) {
-            return InferenceEngine::OK;
-        } else {
+        if (waitStatus == GNA_REQUEST_PENDING) {
+            // request is still pending so Wait() is needed once again
+            return InferenceEngine::RESULT_NOT_READY;
+        }
+        if (waitStatus == GNA_REQUEST_ABORTED) {
             // need to preserve invalid state here to avoid next Wait() from clearing it
             inferRequestIdx = -1;
             return InferenceEngine::INFER_NOT_STARTED;
         }
+        return InferenceEngine::OK;
     }
 };
 }  // namespace GNAPluginNS
index b9861b86cdb8185ee0a939162465cd43a95eecde..92b6a695eea31b1a8332dc15125f6b96c923abbb 100644 (file)
@@ -976,21 +976,26 @@ uint32_t GNAPlugin::QueueInference(const InferenceEngine::BlobMap &inputs, Infer
 }
 
 bool GNAPlugin::Wait(uint32_t request_idx) {
-    return WaitFor(request_idx, MAX_TIMEOUT);
+    return GNA_REQUEST_COMPLETED == WaitFor(request_idx, MAX_TIMEOUT);
 }
 
-bool GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
+GnaWaitStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
 #if GNA_LIB_VER == 2
     auto& nnets = gnaRequestConfigToRequestIdMap;
 #endif
-    if (nnets.size() <= request_idx) return true;    // TODO: GNA2: check whether necessary
+    // TODO: GNA2: check whether necessary
+    if (nnets.size() <= request_idx) return GNA_REQUEST_COMPLETED;
     // already synced TODO: might be copy required ???
-    if (std::get<1>(nnets[request_idx]) == -1) return true;
+    if (std::get<1>(nnets[request_idx]) == -1) return GNA_REQUEST_COMPLETED;
 
     if (gnadevice) {
-        if (!gnadevice->wait(std::get<1>(nnets[request_idx]), millisTimeout)) {
+        const auto waitStatus = gnadevice->wait(std::get<1>(nnets[request_idx]), millisTimeout);
+        if (waitStatus == GNA_REQUEST_ABORTED) {
             std::get<1>(nnets[request_idx]) = -1;
-            return false;
+            return GNA_REQUEST_ABORTED;
+        }
+        if (waitStatus == GNA_REQUEST_PENDING) {
+            return GNA_REQUEST_PENDING;
         }
     }
 
@@ -1090,7 +1095,7 @@ bool GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
 
         output_idx++;
     }
-    return true;
+    return GNA_REQUEST_COMPLETED;
 }
 
 void GNAPlugin::Reset() {
index 54bf674d69495bd031b3840037c29cbf17004d07..99eda6c07d54f98f4cd926ae99af28bfe2564403 100644 (file)
@@ -119,7 +119,7 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
                       InferenceEngine::QueryNetworkResult &res) const override;
     uint32_t QueueInference(const InferenceEngine::BlobMap &input, InferenceEngine::BlobMap &result);
     bool Wait(uint32_t idx);
-    bool WaitFor(uint32_t idx, int64_t millisTimeout);
+    GnaWaitStatus WaitFor(uint32_t idx, int64_t millisTimeout);
 
     InferenceEngine::Parameter GetConfig(const std::string& name,
                                          const std::map<std::string, InferenceEngine::Parameter> & options) const override;