Default Infer leads to correct infer request state (#1562)
authorAnton Pankratv <anton.pankratov@intel.com>
Thu, 17 Sep 2020 08:57:11 +0000 (11:57 +0300)
committerGitHub <noreply@github.com>
Thu, 17 Sep 2020 08:57:11 +0000 (11:57 +0300)
inference-engine/src/plugin_api/cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp

index db3ea04..a0a0a89 100644 (file)
@@ -6,6 +6,7 @@
 
 #include <threading/ie_immediate_executor.hpp>
 #include <threading/ie_itask_executor.hpp>
+#include <threading/ie_istreams_executor.hpp>
 
 #include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
 #include <cpp_interfaces/impl/ie_infer_async_request_thread_safe_internal.hpp>
@@ -72,12 +73,21 @@ public:
      */
     AsyncInferRequestThreadSafeDefault(const InferRequestInternal::Ptr& request,
                                        const ITaskExecutor::Ptr& taskExecutor,
-                                       const ITaskExecutor::Ptr& callbackExecutor)
-        : _syncRequest {request},
-          _requestExecutor {taskExecutor},
-          _callbackExecutor {callbackExecutor},
-          _pipeline {{taskExecutor, [this] {_syncRequest->Infer();}}},
-          _syncPipeline{{std::make_shared<ImmediateExecutor>(), [this] {_syncRequest->Infer();}}} {
+                                       const ITaskExecutor::Ptr& callbackExecutor) :
+        _syncRequest {request},
+        _requestExecutor {taskExecutor},
+        _callbackExecutor {callbackExecutor},
+        _pipeline {{taskExecutor, [this] {_syncRequest->Infer();}}},
+        _syncPipeline{{std::make_shared<ImmediateExecutor>(), [this] {_syncRequest->Infer();}}} {
+        auto streamsExecutor = std::dynamic_pointer_cast<IStreamsExecutor>(taskExecutor);
+        if (streamsExecutor != nullptr) {
+            struct ImmediateStreamsExecutor : public InferenceEngine::ITaskExecutor {
+                explicit ImmediateStreamsExecutor(const IStreamsExecutor::Ptr& streamsExecutor) : _streamsExecutor{streamsExecutor} {}
+                void run(InferenceEngine::Task task) override {_streamsExecutor->Execute(std::move(task));}
+                IStreamsExecutor::Ptr _streamsExecutor;
+            };
+            _syncPipeline = {{std::make_shared<ImmediateStreamsExecutor>(std::move(streamsExecutor)), [this] {_syncRequest->Infer();}}};
+        }
     }
 
     /**
@@ -228,6 +238,7 @@ protected:
         DisableCallbackGuard disableCallbackGuard{_callback};
         _syncRequest->checkBlobs();
         RunFirstStage(_syncPipeline.begin(), _syncPipeline.end(), _syncCallbackExecutor);
+        // If we have exception we should extract it from future using Wait() method
         Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
     }