Add sync/async executors for infer request
author: TolyaTalamanov <anatoliy.talamanov@intel.com>
Thu, 29 Sep 2022 09:47:53 +0000 (09:47 +0000)
committer: TolyaTalamanov <anatoliy.talamanov@intel.com>
Mon, 3 Oct 2022 09:43:50 +0000 (09:43 +0000)
modules/gapi/src/backends/ie/giebackend.cpp

index be44d36..ccb5953 100644 (file)
@@ -826,8 +826,10 @@ std::vector<InferenceEngine::InferRequest> cv::gimpl::ie::IECompiled::createInfe
     return requests;
 }
 
-class cv::gimpl::ie::RequestPool {
+class IInferExecutor {
 public:
+    using Ptr             = std::shared_ptr<IInferExecutor>;
+    using NotifyCallbackF = std::function<void()>;
     using SetInputDataF   = std::function<void(InferenceEngine::InferRequest&)>;
     using ReadOutputDataF = std::function<void(InferenceEngine::InferRequest&, InferenceEngine::StatusCode)>;
 
@@ -835,41 +837,106 @@ public:
     // SetInputDataF - function which set input data.
     // ReadOutputDataF - function which read output data.
     struct Task {
-        SetInputDataF set_input_data;
+        SetInputDataF   set_input_data;
         ReadOutputDataF read_output_data;
     };
 
-    explicit RequestPool(std::vector<InferenceEngine::InferRequest>&& requests);
+    IInferExecutor(IE::InferRequest request, NotifyCallbackF notify)
+        : m_request(std::move(request)),
+          m_notify(std::move(notify)) {
+    };
 
-    void execute(Task&& t);
-    void waitAll();
+    virtual void execute(const Task& task) = 0;
+    virtual ~IInferExecutor() = default;
+
+protected:
+    IE::InferRequest m_request;
+    NotifyCallbackF  m_notify;
+};
+
+// Executor that serves a Task with a blocking Infer() call on the caller's
+// thread (see SyncInferExecutor::execute below).
+// NOTE(review): nothing in this patch ever constructs SyncInferExecutor --
+// the RequestPool constructor always creates AsyncInferExecutor. Confirm the
+// sync-API path (former m_use_sync_api) is wired up elsewhere.
+class SyncInferExecutor : public IInferExecutor {
+    using IInferExecutor::IInferExecutor;
+    virtual void execute(const IInferExecutor::Task& task) override;
+};
+
+// Runs the whole Task synchronously on the caller's thread.
+// m_notify() fires exactly once, whether inference succeeds or throws.
+void SyncInferExecutor::execute(const IInferExecutor::Task& task) {
+    try {
+        task.set_input_data(m_request);
+        m_request.Infer();
+        task.read_output_data(m_request, IE::StatusCode::OK);
+    } catch (...) {
+        // NB: Return this executor to the idle pool before propagating,
+        // otherwise its slot would be lost forever.
+        m_notify();
+        throw;
+    }
+    // NB: Notify pool that executor has finished.
+    m_notify();
+}
+
+// Executor that serves a Task via StartAsync(); read_output_data and the
+// pool notification happen later inside the IE completion callback.
+class AsyncInferExecutor : public IInferExecutor {
+public:
+    using IInferExecutor::IInferExecutor;
+    virtual void execute(const IInferExecutor::Task& task) override;
 
 private:
     void callback(Task task,
-                  size_t id,
                   IE::InferRequest request,
                   IE::StatusCode code) noexcept;
+};
+
+// Kicks off asynchronous inference. On success, completion (and the pool
+// notification) is delivered later in AsyncInferExecutor::callback; on
+// failure the stored callback is detached and the pool notified here so the
+// executor is not leaked.
+void AsyncInferExecutor::execute(const IInferExecutor::Task& task) {
+    using namespace std::placeholders;
+    using callback_t = std::function<void(IE::InferRequest, IE::StatusCode)>;
+    // NB: std::bind stores `task` by value, so the Task outlives this call
+    // even though it was passed by const reference.
+    m_request.SetCompletionCallback(
+            static_cast<callback_t>(
+                std::bind(&AsyncInferExecutor::callback, this, task, _1, _2)));
+    try {
+        task.set_input_data(m_request);
+        m_request.StartAsync();
+    } catch (...) {
+        // NB: Drop the bound callback to release captured resources, then
+        // return this executor to the idle pool before rethrowing.
+        m_request.SetCompletionCallback([](){});
+        m_notify();
+        throw;
+    }
+}
+
+// Completion handler invoked by InferenceEngine when StartAsync() finishes.
+// Declared noexcept: an exception escaping read_output_data here would call
+// std::terminate -- failures are expected to be reported via `code` instead.
+void AsyncInferExecutor::callback(IInferExecutor::Task task,
+                                  IE::InferRequest     request,
+                                  IE::StatusCode       code) noexcept {
+    task.read_output_data(request, code);
+    // NB: Reset the callback to free the resources captured by std::bind.
+    request.SetCompletionCallback([](){});
+    // NB: Notify pool that executor has finished.
+    m_notify();
+}
+
+// Pool of executors wrapping InferRequests. getIdleRequest() pops an id from
+// m_idle_ids (presumably a blocking pop -- confirm QueueClass semantics) and
+// hands out the matching executor; executors push their id back via the
+// notify callback bound in the constructor.
+class cv::gimpl::ie::RequestPool {
+public:
+
+    explicit RequestPool(std::vector<InferenceEngine::InferRequest>&& requests);
+
+    IInferExecutor::Ptr getIdleRequest();
+    void waitAll();
+
+private:
     void setup();
-    void releaseRequest(const int id);
+    void release(const int id);
 
-    QueueClass<size_t>                         m_idle_ids;
-    std::vector<InferenceEngine::InferRequest> m_requests;
-    bool                                       m_use_sync_api = false;
+    QueueClass<size_t>               m_idle_ids;
+    std::vector<IInferExecutor::Ptr> m_requests;
 };
 
-void cv::gimpl::ie::RequestPool::releaseRequest(const int id) {
-    if (!m_use_sync_api) {
-        auto& request = m_requests[id];
-        request.SetCompletionCallback([](){});
-    }
+// Bound into every executor's notify callback; returns slot `id` to the idle
+// queue. Completion-callback cleanup now lives in the executors themselves.
+void cv::gimpl::ie::RequestPool::release(const int id) {
     m_idle_ids.push(id);
 }
 
 // RequestPool implementation //////////////////////////////////////////////
-cv::gimpl::ie::RequestPool::RequestPool(std::vector<InferenceEngine::InferRequest>&& requests)
-    : m_requests(std::move(requests)) {
-        setup();
+cv::gimpl::ie::RequestPool::RequestPool(std::vector<InferenceEngine::InferRequest>&& requests) {
+    for (size_t i = 0; i < requests.size(); ++i) {
+        m_requests.emplace_back(
+                std::make_shared<AsyncInferExecutor>(std::move(requests[i]),  // NB: requests[i], not [0] -- [0] is moved-from after the 1st iteration
+                                                     std::bind(&RequestPool::release, this, i)));
     }
+    setup();
+}
 
 void cv::gimpl::ie::RequestPool::setup() {
     for (size_t i = 0; i < m_requests.size(); ++i) {
@@ -877,44 +944,10 @@ void cv::gimpl::ie::RequestPool::setup() {
     }
 }
 
-void cv::gimpl::ie::RequestPool::execute(cv::gimpl::ie::RequestPool::Task&& task) {
+// Waits for an idle executor (m_idle_ids.pop presumably blocks -- confirm
+// QueueClass semantics) and hands it out; the executor re-registers itself
+// through the release() callback once its Task completes.
+IInferExecutor::Ptr cv::gimpl::ie::RequestPool::getIdleRequest() {
     size_t id = 0u;
     m_idle_ids.pop(id);
-    auto& request = m_requests[id];
-
-    try {
-        task.set_input_data(request);
-        if (m_use_sync_api) {
-            request.Infer();
-            task.read_output_data(request, IE::StatusCode::OK);
-            releaseRequest(id);
-        } else {
-            using namespace std::placeholders;
-            using callback_t = std::function<void(IE::InferRequest, IE::StatusCode)>;
-            request.SetCompletionCallback(
-                    static_cast<callback_t>(
-                        std::bind(&cv::gimpl::ie::RequestPool::callback, this,
-                                  task, id, _1, _2)));
-            request.StartAsync();
-        }
-    } catch (...) {
-        // NB: InferRequest is already marked as busy
-        // in case of exception need to return it back to the idle.
-        releaseRequest(id);
-        throw;
-    }
-}
-
-void cv::gimpl::ie::RequestPool::callback(cv::gimpl::ie::RequestPool::Task task,
-                                          size_t id,
-                                          IE::InferRequest request,
-                                          IE::StatusCode code) noexcept {
-    // NB: Inference is over.
-    // 1. Run callback
-    // 2. Destroy callback to free resources.
-    // 3. Mark InferRequest as idle.
-    task.read_output_data(request, code);
-    releaseRequest(id);
+    return m_requests[id];
 }
 
 // NB: Not thread-safe.
@@ -1330,8 +1363,8 @@ struct Infer: public cv::detail::KernelTag {
     static void run(std::shared_ptr<IECallContext>  ctx,
                     cv::gimpl::ie::RequestPool     &reqPool) {
         using namespace std::placeholders;
-        reqPool.execute(
-                cv::gimpl::ie::RequestPool::Task {
+        reqPool.getIdleRequest()->execute(
+                IInferExecutor::Task {
                     [ctx](InferenceEngine::InferRequest &req) {
                         // non-generic version for now:
                         // - assumes all inputs/outputs are always Mats
@@ -1440,8 +1473,8 @@ struct InferROI: public cv::detail::KernelTag {
     static void run(std::shared_ptr<IECallContext>  ctx,
                     cv::gimpl::ie::RequestPool     &reqPool) {
         using namespace std::placeholders;
-        reqPool.execute(
-                cv::gimpl::ie::RequestPool::Task {
+        reqPool.getIdleRequest()->execute(
+                IInferExecutor::Task {
                     [ctx](InferenceEngine::InferRequest &req) {
                         GAPI_Assert(ctx->uu.params.num_in == 1);
                         auto&& this_roi = ctx->inArg<cv::detail::OpaqueRef>(0).rref<cv::Rect>();
@@ -1579,8 +1612,8 @@ struct InferList: public cv::detail::KernelTag {
         for (auto&& it : ade::util::indexed(in_roi_vec)) {
                   auto  pos = ade::util::index(it);
             const auto& rc  = ade::util::value(it);
-            reqPool.execute(
-                cv::gimpl::ie::RequestPool::Task {
+            reqPool.getIdleRequest()->execute(
+                IInferExecutor::Task {
                     [ctx, rc, this_blob](InferenceEngine::InferRequest &req) {
                         setROIBlob(req, ctx->uu.params.input_names[0u], this_blob, rc, *ctx);
                     },
@@ -1734,8 +1767,8 @@ struct InferList2: public cv::detail::KernelTag {
 
         PostOutputsList callback(list_size, ctx, std::move(cached_dims));
         for (const auto &list_idx : ade::util::iota(list_size)) {
-            reqPool.execute(
-                cv::gimpl::ie::RequestPool::Task {
+            reqPool.getIdleRequest()->execute(
+                IInferExecutor::Task {
                     [ctx, list_idx, list_size, blob_0](InferenceEngine::InferRequest &req) {
                         for (auto in_idx : ade::util::iota(ctx->uu.params.num_in)) {
                             const auto &this_vec = ctx->inArg<cv::detail::VectorRef>(in_idx+1u);